diff --git a/BUILD.bazel b/BUILD.bazel index baa9cec726b..e54d106914b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -327,7 +327,6 @@ xnnpack_cc_library( ":microparams", ":mutex", ":xnnpack_h", - "@FP16", ], ) @@ -428,7 +427,6 @@ xnnpack_cc_library( ":packing", ":prod_microkernels", ":xnnpack_h", - "@FP16", ] + select({ ":cpuinfo_enabled": ["@cpuinfo"], "//conditions:default": [], @@ -449,13 +447,19 @@ xnnpack_cc_library( ], ) +xnnpack_cc_library( + name = "fp16", + hdrs = ["src/xnnpack/fp16.h"], + compatible_with = [], +) + xnnpack_cc_library( name = "math", hdrs = ["src/xnnpack/math.h"], deps = [ ":common", ":config_hdrs", - "@FP16", + ":fp16", ], ) @@ -598,7 +602,6 @@ xnnpack_cc_library( ":common", ":math", ":microparams", - "@FP16", ], ) @@ -777,7 +780,6 @@ xnnpack_cc_library( ":microparams", ":operator_h", ":xnnpack_h", - "@FP16", "@FXdiv", ], ) @@ -796,7 +798,6 @@ xnnpack_cxx_library( ":params", ":unaligned", ":xnnpack_h", - "@FP16", ] + xnnpack_if_kleidiai_enabled([ "@KleidiAI//kai/ukernels/matmul", "@KleidiAI//kai/ukernels/matmul:rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", @@ -931,6 +932,7 @@ xnnpack_cc_library( ":allocator", ":cache", ":common", + ":fp16", ":indirection", ":logging", ":math", @@ -946,7 +948,6 @@ xnnpack_cc_library( ":params", ":quantization", ":xnnpack_h", - "@FP16", "@pthreadpool", ] + select({ "//conditions:default": [], @@ -969,6 +970,7 @@ xnnpack_cc_library( ":cache", ":common", ":config_hdrs", + ":fp16", ":hardware_config", ":internal", ":logging", @@ -983,7 +985,6 @@ xnnpack_cc_library( ":params", ":requantization", ":xnnpack_h", - "@FP16", "@pthreadpool", ] + xnnpack_slinky_deps(), ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 948c69d8243..1fdaf28f02a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -304,16 +304,6 @@ IF(NOT XNNPACK_USE_SYSTEM_LIBS) SET(CPUINFO_SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo-source" CACHE STRING "cpuinfo source directory") ENDIF() - IF(NOT DEFINED FP16_SOURCE_DIR) - MESSAGE(STATUS "Downloading FP16 to 
${CMAKE_BINARY_DIR}/FP16-source (define FP16_SOURCE_DIR to avoid it)") - CONFIGURE_FILE(cmake/DownloadFP16.cmake "${CMAKE_BINARY_DIR}/FP16-download/CMakeLists.txt") - EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") - EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") - SET(FP16_SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" CACHE STRING "FP16 source directory") - ENDIF() - IF(NOT DEFINED FXDIV_SOURCE_DIR) MESSAGE(STATUS "Downloading FXdiv to ${CMAKE_BINARY_DIR}/FXdiv-source (define FXDIV_SOURCE_DIR to avoid it)") CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CMAKE_BINARY_DIR}/FXdiv-download/CMakeLists.txt") @@ -1116,42 +1106,7 @@ IF(XNNPACK_BUILD_LIBRARY) TARGET_LINK_LIBRARIES(XNNPACK PRIVATE fxdiv) ENDIF() -# ---[ Configure FP16 -IF(NOT TARGET fp16) - IF(NOT XNNPACK_USE_SYSTEM_LIBS) - SET(FP16_BUILD_TESTS OFF CACHE BOOL "") - SET(FP16_BUILD_BENCHMARKS OFF CACHE BOOL "") - ADD_SUBDIRECTORY( - "${FP16_SOURCE_DIR}" - "${CMAKE_BINARY_DIR}/FP16") - ELSE() - FIND_FILE(FP16_HDR fp16.h PATH_SUFFIXES include PATHS "${FP16_SOURCE_DIR}") - IF(NOT FP16_HDR) - MESSAGE(FATAL_ERROR "Cannot find fp16") - ENDIF() - ADD_LIBRARY(fp16 STATIC "${FP16_HDR}") - TARGET_INCLUDE_DIRECTORIES(fp16 INTERFACE "${FP16_SOURCE_DIR}/include") - SET_PROPERTY(TARGET fp16 PROPERTY LINKER_LANGUAGE C) - ENDIF() -ENDIF() -IF(XNNPACK_BUILD_ALL_MICROKERNELS) - TARGET_LINK_LIBRARIES(microkernels-all PRIVATE fp16) -ENDIF() -TARGET_LINK_LIBRARIES(microkernels-prod PRIVATE fp16) -TARGET_LINK_LIBRARIES(microparams-init PRIVATE fp16) -TARGET_LINK_LIBRARIES(packing PRIVATE fp16) -TARGET_LINK_LIBRARIES(indirection PRIVATE fp16) -TARGET_LINK_LIBRARIES(memory PRIVATE fp16) -TARGET_LINK_LIBRARIES(normalization PRIVATE fp16) -TARGET_LINK_LIBRARIES(microkernel-utils PRIVATE fp16) -TARGET_LINK_LIBRARIES(cache PRIVATE fp16) -TARGET_LINK_LIBRARIES(operator-utils PRIVATE fp16) 
IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(subgraph PRIVATE fp16) - TARGET_LINK_LIBRARIES(operators PRIVATE fp16) - TARGET_LINK_LIBRARIES(operator-run PRIVATE fp16) - - TARGET_LINK_LIBRARIES(XNNPACK PRIVATE fp16) INSTALL(TARGETS XNNPACK LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -1212,7 +1167,7 @@ IF(XNNPACK_BUILD_TESTS) ADD_LIBRARY(gemm-microkernel-tester STATIC test/gemm-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(gemm-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE xnnpack-base fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE xnnpack-base pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE packing) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE kleidiai) @@ -1221,34 +1176,34 @@ IF(XNNPACK_BUILD_TESTS) ADD_LIBRARY(unary-operator-tester STATIC test/unary-operator-tester.cc) TARGET_INCLUDE_DIRECTORIES(unary-operator-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(unary-operator-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(unary-operator-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(dwconv-microkernel-tester STATIC test/dwconv-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(dwconv-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PUBLIC next-prime) ADD_LIBRARY(vbinary-microkernel-tester STATIC test/vbinary-microkernel-tester.cc) SET_TARGET_PROPERTIES(vbinary-microkernel-tester PROPERTIES CXX_EXTENSIONS YES) TARGET_INCLUDE_DIRECTORIES(vbinary-microkernel-tester PRIVATE . 
include src test) - TARGET_LINK_LIBRARIES(vbinary-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vbinary-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(vcvt-microkernel-tester STATIC test/vcvt-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(vcvt-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vcvt-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vcvt-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(vunary-microkernel-tester STATIC test/vunary-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(vunary-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vunary-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vunary-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(vunary-microkernel-tester PUBLIC next-prime) ADD_LIBRARY(convolution-test-helpers OBJECT test/convolution-test-helpers.cc) TARGET_INCLUDE_DIRECTORIES(convolution-test-helpers PRIVATE include src) - TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE xnnpack-base fp16) + TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE xnnpack-base) ADD_LIBRARY(packq-microkernel-tester STATIC test/packq-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(packq-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE kleidiai) ENDIF() @@ -1268,7 +1223,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . 
include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gtest GTest::gtest_main hardware-config @@ -1295,7 +1249,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1333,7 +1286,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gtest GTest::gtest_main unary-operator-tester @@ -1344,7 +1296,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(binary-elementwise-nd-test test/binary-elementwise-nd.cc) TARGET_INCLUDE_DIRECTORIES(binary-elementwise-nd-test PRIVATE src test) TARGET_LINK_LIBRARIES(binary-elementwise-nd-test PRIVATE - fp16 GTest::gtest GTest::gtest_main XNNPACK) @@ -1362,7 +1313,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1425,7 +1375,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1444,7 +1393,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE convolution-test-helpers - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1539,7 +1487,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1580,7 +1527,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . 
include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE dwconv-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1633,7 +1579,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE gemm-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1654,7 +1599,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE packq-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1727,7 +1671,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vbinary-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1760,7 +1703,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vcvt-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1814,7 +1756,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vunary-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1849,7 +1790,7 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(operator-utils-test test/operator-utils.cc) TARGET_INCLUDE_DIRECTORIES(operator-utils-test PRIVATE include src) - TARGET_LINK_LIBRARIES(operator-utils-test PRIVATE XNNPACK GTest::gtest GTest::gtest_main pthreadpool fp16) + TARGET_LINK_LIBRARIES(operator-utils-test PRIVATE XNNPACK GTest::gtest GTest::gtest_main pthreadpool) ENDIF() @@ -1903,14 +1844,14 @@ IF(XNNPACK_BUILD_BENCHMARKS) # Helper libraries ADD_LIBRARY(packq-benchmark STATIC bench/packq-benchmark.cc) TARGET_INCLUDE_DIRECTORIES(packq-benchmark PRIVATE . 
include src bench) - TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils fp16) + TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE kleidiai) ENDIF() ADD_LIBRARY(gemm-benchmark STATIC bench/gemm-benchmark.cc) TARGET_INCLUDE_DIRECTORIES(gemm-benchmark PRIVATE . include src bench) - TARGET_LINK_LIBRARIES(gemm-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils fp16) + TARGET_LINK_LIBRARIES(gemm-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(gemm-benchmark PUBLIC kleidiai) ENDIF() @@ -1936,7 +1877,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(bench-models PRIVATE bench-utils benchmark::benchmark - fp16 models XNNPACK) @@ -1971,7 +1911,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(${BENCH}-bench PRIVATE bench-utils benchmark::benchmark - fp16 XNNPACK ) ENDFOREACH() @@ -2068,7 +2007,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(${BENCH}-bench PRIVATE bench-utils benchmark::benchmark - fp16 gemm-benchmark hardware-config im2col diff --git a/WORKSPACE b/WORKSPACE index 54d3841939e..b140e022e05 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -63,19 +63,6 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadGoogleBenchmark.cmake) -# LINT.IfChange -# FP16 library, used for half-precision conversions -http_archive( - name = "FP16", - build_file = "@//third_party:FP16.BUILD", - sha256 = "e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70", - strip_prefix = "FP16-0a92994d729ff76a58f692d3028ca1b64b145d91", - urls = [ - "https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip", - ], -) -# LINT.ThenChange(cmake/DownloadFP16.cmake) - # LINT.IfChange # FXdiv library, used for repeated integer division by the same factor http_archive( diff --git a/bench/BUILD.bazel b/bench/BUILD.bazel index 
a8758ea9539..fbe9444eadf 100644 --- a/bench/BUILD.bazel +++ b/bench/BUILD.bazel @@ -27,7 +27,6 @@ load( MICROKERNEL_BENCHMARK_DEPS = [ ":bench_utils", - "@FP16", "//:aligned_allocator", "//:all_microkernels", "//:common", @@ -48,7 +47,6 @@ OPERATOR_BENCHMARK_DEPS = [ "//:cache", "//:common", "//:math", - "@FP16", ] xnnpack_cxx_library( diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc index 2a9b6576529..116acbfdbd2 100644 --- a/bench/gemm-benchmark.cc +++ b/bench/gemm-benchmark.cc @@ -25,7 +25,6 @@ #include #include -#include #include "bench/utils.h" #include diff --git a/build_params.bzl b/build_params.bzl index 6aec1584971..92b4025364d 100644 --- a/build_params.bzl +++ b/build_params.bzl @@ -274,7 +274,6 @@ XNNPACK_PARAMS_FOR_ARCH = { ], extra_deps = [ "//:config_hdrs", - "@FP16", "@FXdiv", ], ), @@ -523,7 +522,6 @@ XNNPACK_PARAMS_FOR_ARCH = { "-mno-sse4.2", ], extra_deps = [ - "@FP16", ], msvc_x86_32_copts = ["/arch:SSE2"], msvc_x86_64_copts = ["/arch:SSE2"], diff --git a/cmake/DownloadFP16.cmake b/cmake/DownloadFP16.cmake deleted file mode 100644 index a3321d9d40b..00000000000 --- a/cmake/DownloadFP16.cmake +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# Copyright 2019 Google LLC -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) - -PROJECT(fp16-download NONE) - -# Set file timestamps to the time of extraction. 
-IF(POLICY CMP0135) - CMAKE_POLICY(SET CMP0135 NEW) -ENDIF() - -INCLUDE(ExternalProject) -ExternalProject_Add(fp16 - URL https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip - URL_HASH SHA256=e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70 - SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" - BINARY_DIR "${CMAKE_BINARY_DIR}/FP16" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - TEST_COMMAND "" -) diff --git a/cmake/gen/avxvnniint8_microkernels.cmake b/cmake/gen/avxvnniint8_microkernels.cmake index 60e0412af5f..ee2ac244a81 100644 --- a/cmake/gen/avxvnniint8_microkernels.cmake +++ b/cmake/gen/avxvnniint8_microkernels.cmake @@ -15,6 +15,7 @@ SET(PROD_AVXVNNIINT8_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8c8-minmax-fp32-avxvnniint8-prfm.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x8c8-minmax-fp32-avxvnniint8-prfm.c) -SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS) +SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c) SET(ALL_AVXVNNIINT8_MICROKERNEL_SRCS ${PROD_AVXVNNIINT8_MICROKERNEL_SRCS} + ${NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS}) diff --git a/cmake/gen/rvvfp16arith_microkernels.cmake b/cmake/gen/rvvfp16arith_microkernels.cmake index 0f32a4a46ae..03df4b84f53 100644 --- a/cmake/gen/rvvfp16arith_microkernels.cmake +++ b/cmake/gen/rvvfp16arith_microkernels.cmake @@ -15,6 +15,10 @@ SET(NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u1v.c src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u2v.c src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u4v.c - src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c) + src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c) SET(ALL_RVVFP16ARITH_MICROKERNEL_SRCS 
${PROD_RVVFP16ARITH_MICROKERNEL_SRCS} + ${NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS}) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index d30bceb4257..8879af0b608 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -635,7 +635,9 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c diff --git a/gen/avxvnniint8_microkernels.bzl b/gen/avxvnniint8_microkernels.bzl index 97f1657ee4a..a1b149b017c 100644 --- a/gen/avxvnniint8_microkernels.bzl +++ b/gen/avxvnniint8_microkernels.bzl @@ -13,6 +13,7 @@ PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ ] NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c", ] ALL_AVXVNNIINT8_MICROKERNEL_SRCS = PROD_AVXVNNIINT8_MICROKERNEL_SRCS + NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS diff --git a/gen/rvvfp16arith_microkernels.bzl b/gen/rvvfp16arith_microkernels.bzl index 139aa872254..3d0d9afc559 100644 --- a/gen/rvvfp16arith_microkernels.bzl +++ b/gen/rvvfp16arith_microkernels.bzl @@ -13,6 +13,10 @@ NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS = [ "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u2v.c", "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u4v.c", "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c", + 
"src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c", ] ALL_RVVFP16ARITH_MICROKERNEL_SRCS = PROD_RVVFP16ARITH_MICROKERNEL_SRCS + NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index 7ef2fc7f766..21e3b9f9b0c 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -632,7 +632,9 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c", "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c", "src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c", diff --git a/scripts/generate-f32-f16-vcvt.sh b/scripts/generate-f32-f16-vcvt.sh index e34f227224f..d7dae496c53 100755 --- a/scripts/generate-f32-f16-vcvt.sh +++ b/scripts/generate-f32-f16-vcvt.sh @@ -13,6 +13,12 @@ tools/xngen src/f32-f16-vcvt/neon.c.in -D BATCH_TILE=32 -o src/f32-f16-vcvt/gen/ tools/xngen src/f32-f16-vcvt/neonfp16.c.in -D BATCH_TILE=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-neonfp16-u8.c & tools/xngen src/f32-f16-vcvt/neonfp16.c.in -D BATCH_TILE=16 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-neonfp16-u16.c & +################################ RISC-V Vector ################################ +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=1 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=2 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=4 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=8 -o 
src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c & + ################################# x86 128-bit ################################# tools/xngen src/f32-f16-vcvt/sse.c.in -D SSE=2 -D AVX=0 -D BATCH_TILE=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-sse2-u8.c & tools/xngen src/f32-f16-vcvt/sse.c.in -D SSE=2 -D AVX=0 -D BATCH_TILE=16 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-sse2-u16.c & diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 281bba59237..912f9400796 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -23,4 +23,11 @@ tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=4 -D TYPE=int8_t -o src/q tools/xngen src/x8-packw/kr-scalar.c.in -D NR=32 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c & tools/xngen src/x8-packw/kr-scalar.c.in -D NR=64 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c & + +### C8 packing +tools/xngen src/x8-packw/kr-scalar.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c & +tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c & + +tools/xngen src/x8-packw/kr-avxvnniint8.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c & + wait diff --git a/scripts/genxnn b/scripts/genxnn new file mode 100755 index 00000000000..12a5888c8c9 --- /dev/null +++ b/scripts/genxnn @@ -0,0 +1,92 @@ +#!/bin/bash +# +# Regenerates the microkernel sources: launches the scripts/generate-*.sh +# generators in parallel, waits for all of them, then runs +# tools/update-microkernels.py. Run it from the repository root. +# +# Options: +# -v verbose: print each generator script as it is launched +# -d debug: also print the output of scripts/check_files_changed.py +# -f force: launch every generator, skipping the change check + +# Launches generator script $1 in the background, but only when +# scripts/check_files_changed.py prints something for it. +function my_generate1() { + files_changed=$(scripts/check_files_changed.py "$1") + if [[ $files_changed ]] + then + if ${VERBOSE}; then + echo "$1" + fi + if ${DEBUG}; then + echo $files_changed + fi + # Left unquoted so every word of the checker output is touched separately. + touch $files_changed + bash -c "$1" & + fi +} + +# Launches generator script $1 in the background unconditionally. +function my_generatef() { + if ${VERBOSE}; then + echo "$1" + fi + bash -c "$1" & +} + +export -f my_generate1 +export -f my_generatef + +export FORCE='false' +export VERBOSE='false' +export DEBUG='false' + +while getopts ':vfd' 'OPTKEY'; do + case
${OPTKEY} in + 'v') + export VERBOSE='true' + ;; + 'd') + export DEBUG='true' + ;; + 'f') + export FORCE='true' + ;; + '?') + echo "INVALID OPTION -- ${OPTARG}" >&2 + exit 1 + ;; + ':') + echo "MISSING ARGUMENT for option -- ${OPTARG}" >&2 + exit 1 + ;; + *) + echo "UNIMPLEMENTED OPTION -- ${OPTKEY}" >&2 + exit 1 + ;; + esac +done + +# [optional] Remove all options processed by getopts. +shift $(( OPTIND - 1 )) +[[ "${1}" == "--" ]] && shift + +#pushd $(g4 g4d) +if ${FORCE}; then + generate=my_generatef +else + generate=my_generate1 +fi +# Call the launcher from this shell instead of a `find -exec bash -c` child: +# jobs backgrounded by such a child are not children of this script, so the +# `wait` below would return while generators were still running. +while IFS= read -r -d '' script; do + "${generate}" "${script}" +done < <(find scripts/ -name 'generate-*.sh' -print0) +wait +if ${VERBOSE}; then + echo ./tools/update-microkernels.py +fi +./tools/update-microkernels.py +#popd diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index b464b1af0ba..26a7d8250d6 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -1999,7 +1999,7 @@ static void init_qs8_lrelu_config(void) { } static void init_qs8_to_f16_cvt_config(void) { - #if XNN_ARCH_ARM || XNN_ARCH_ARM64 + #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); if (hardware_config->use_arm_neon_fp16_arith) { diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c index 20469b41d53..fb9e4dc17b6 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u1( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c index 017b78406bf..e3561a3ccb3 100644 ---
a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u2( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c index c88ab86b633..f088ddf5684 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u3( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c index c9bd1b86128..c2cb7a3594e 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u4( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c index 7d9bda6d202..c995a3df11a 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u1( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c index 82074e6fcaf..00620ce80b3 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include 
"xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u2( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c index 3ccb4aaff00..c4f849e687e 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u3( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c index 979725bdce9..86e42c4ca57 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u4( size_t batch, diff --git a/src/f32-f16-vcvt/f32-f16-vcvt.h b/src/f32-f16-vcvt/f32-f16-vcvt.h index dfd6fb266f5..fbcaa08526d 100644 --- a/src/f32-f16-vcvt/f32-f16-vcvt.h +++ b/src/f32-f16-vcvt/f32-f16-vcvt.h @@ -58,6 +58,13 @@ XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__wasmrelaxedsimd_u24, 24 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__wasmrelaxedsimd_u32, 32, false, float, xnn_float16, void, NULL) #endif // XNN_ARCH_WASMRELAXEDSIMD +#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_FP16_VECTOR +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u1v, 1, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u2v, 2, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u4v, 4, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u8v, 8, true, float, xnn_float16, void, NULL) +#endif // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_FP16_VECTOR + XNN_CVT_UKERNEL_WITH_PARAMS(0,
xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u1, 1, false, float, xnn_float16, void, NULL) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u2, 2, false, float, xnn_float16, void, NULL) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u3, 3, false, float, xnn_float16, void, NULL) diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c new file mode 100644 index 00000000000..e84b6d04699 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <riscv_vector.h> + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u1v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m1(batch); batch -= n; + + vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n; + + vfloat16mf2_t y_f16v = __riscv_vfncvt_f_f_w_f16mf2(x_f32v, n); + + __riscv_vse16_v_f16mf2(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c new file mode 100644 index 00000000000..69dd22f841e --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc.
+// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <riscv_vector.h> + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u2v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m2(batch); batch -= n; + + vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n; + + vfloat16m1_t y_f16v = __riscv_vfncvt_f_f_w_f16m1(x_f32v, n); + + __riscv_vse16_v_f16m1(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c new file mode 100644 index 00000000000..116bb4d35bb --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree.
+ +#include <assert.h> + +#include <riscv_vector.h> + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u4v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m4(batch); batch -= n; + + vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n; + + vfloat16m2_t y_f16v = __riscv_vfncvt_f_f_w_f16m2(x_f32v, n); + + __riscv_vse16_v_f16m2(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c new file mode 100644 index 00000000000..dc0cadfb239 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree.
+ +#include <assert.h> + +#include <riscv_vector.h> + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u8v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m8(batch); batch -= n; + + vfloat32m8_t x_f32v = __riscv_vle32_v_f32m8(input, n); input += n; + + vfloat16m4_t y_f16v = __riscv_vfncvt_f_f_w_f16m4(x_f32v, n); + + __riscv_vse16_v_f16m4(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/rvvfp16arith.c.in b/src/f32-f16-vcvt/rvvfp16arith.c.in new file mode 100644 index 00000000000..da136047268 --- /dev/null +++ b/src/f32-f16-vcvt/rvvfp16arith.c.in @@ -0,0 +1,38 @@ +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree.
+ +$assert LMUL in [1, 2, 4, 8] +$LMUL_16 = {1: "f2", 2: "1", 4: "2", 8: "4"}[LMUL] +#include <assert.h> + +#include <riscv_vector.h> + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u${LMUL}v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m${LMUL}(batch); batch -= n; + + vfloat32m${LMUL}_t x_f32v = __riscv_vle32_v_f32m${LMUL}(input, n); input += n; + + vfloat16m${LMUL_16}_t y_f16v = __riscv_vfncvt_f_f_w_f16m${LMUL_16}(x_f32v, n); + + __riscv_vse16_v_f16m${LMUL_16}(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-qs8-vcvt/scalar-fmagic.c.in b/src/f32-qs8-vcvt/scalar-fmagic.c.in index 7c3bfdfa0ba..94218c1745a 100644 --- a/src/f32-qs8-vcvt/scalar-fmagic.c.in +++ b/src/f32-qs8-vcvt/scalar-fmagic.c.in @@ -12,8 +12,6 @@ $assert IDATATYPE == "F16" and ODATATYPE == "QS8" or IDATATYPE == "F32" #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -$if IDATATYPE == "F16": - #include $INPUT_T = {"F16": "xnn_float16", "F32": "float"}[IDATATYPE] $XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[ODATATYPE] diff --git a/src/f32-qs8-vcvt/scalar-imagic.c.in b/src/f32-qs8-vcvt/scalar-imagic.c.in index ae1e474b3ec..2d83fb24ed0 100644 --- a/src/f32-qs8-vcvt/scalar-imagic.c.in +++ b/src/f32-qs8-vcvt/scalar-imagic.c.in @@ -12,8 +12,6 @@ $assert IDATATYPE == "F16" and ODATATYPE == "QS8" or IDATATYPE == "F32" #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -$if IDATATYPE == "F16": - #include $INPUT_T = {"F16": "xnn_float16", "F32": "float"}[IDATATYPE] $XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[ODATATYPE] diff --git a/src/f32-rminmax/scalar.c.in b/src/f32-rminmax/scalar.c.in index e26a6584b6c..1fbf58ff951 100644 --- a/src/f32-rminmax/scalar.c.in +++
b/src/f32-rminmax/scalar.c.in @@ -12,7 +12,7 @@ $if not WASM: #include "xnnpack/math.h" #include "xnnpack/reduce.h" $if DATATYPE == "F16": - #include + #include "xnnpack/fp16.h" $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS $MAX = "__builtin_wasm_max_f32" if WASM else "math_max_f32" @@ -30,7 +30,7 @@ void xnn_${DATATYPE.lower()}_r${OP.lower()}_ukernel__${ISA}_u${BATCH_TILE}${ACC_ $elif DATATYPE == "F16": const void* input, void* output, - const union xnn_${DATATYPE.lower()}_default_params params[restrict XNN_MIN_ELEMENTS(1)]) + const struct xnn_${DATATYPE.lower()}_default_params params[restrict XNN_MIN_ELEMENTS(1)]) { assert(batch != 0); assert(batch % sizeof(${ITYPE}) == 0); diff --git a/src/operators/average-pooling-nhwc.c b/src/operators/average-pooling-nhwc.c index 754da020202..7a217c37d85 100644 --- a/src/operators/average-pooling-nhwc.c +++ b/src/operators/average-pooling-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/convolution-nchw.c b/src/operators/convolution-nchw.c index fc4fe8c9d5b..6be0bf5a2b1 100644 --- a/src/operators/convolution-nchw.c +++ b/src/operators/convolution-nchw.c @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index 0125b6cb3cb..e0eb6265ccb 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/deconvolution-nhwc.c b/src/operators/deconvolution-nhwc.c index 60a44bc158b..ddfe9650723 100644 --- a/src/operators/deconvolution-nhwc.c +++ b/src/operators/deconvolution-nhwc.c @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" 
#include "xnnpack/cache.h" diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c index ca5d69fdc6f..c3eb10fab03 100644 --- a/src/operators/dynamic-fully-connected-nc.c +++ b/src/operators/dynamic-fully-connected-nc.c @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/global-average-pooling-ncw.c b/src/operators/global-average-pooling-ncw.c index ef5ce1a7749..6e7c3700200 100644 --- a/src/operators/global-average-pooling-ncw.c +++ b/src/operators/global-average-pooling-ncw.c @@ -10,13 +10,13 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microparams-init.h" diff --git a/src/operators/global-average-pooling-nwc.c b/src/operators/global-average-pooling-nwc.c index fa50945dc5b..e04b7943258 100644 --- a/src/operators/global-average-pooling-nwc.c +++ b/src/operators/global-average-pooling-nwc.c @@ -14,13 +14,13 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microparams.h" diff --git a/src/operators/max-pooling-nhwc.c b/src/operators/max-pooling-nhwc.c index 43ad86552cd..e8dd7413d37 100644 --- a/src/operators/max-pooling-nhwc.c +++ b/src/operators/max-pooling-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/scaled-dot-product-attention-nhtc.c b/src/operators/scaled-dot-product-attention-nhtc.c index 1de22ee1d87..a278c7c044f 
100644 --- a/src/operators/scaled-dot-product-attention-nhtc.c +++ b/src/operators/scaled-dot-product-attention-nhtc.c @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/softmax-nc.c b/src/operators/softmax-nc.c index 0d4db7f1443..75e57edc9c9 100644 --- a/src/operators/softmax-nc.c +++ b/src/operators/softmax-nc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c index 0cd6baa9955..2270934695a 100644 --- a/src/operators/unary-elementwise-nc.c +++ b/src/operators/unary-elementwise-nc.c @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/packing.cc b/src/packing.cc index 9c742ad6e8c..c37b06b5902 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -28,7 +28,6 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #endif // XNN_ENABLE_KLEIDIAI -#include extern "C" { diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c index 80836d13ca5..5be5406c57f 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c @@ -12,7 +12,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar( @@ -98,11 +97,11 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar( w = (const float*) w + 2; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); - const
float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c index b84b84c0f1a..215aa89d397 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar( @@ -127,13 +126,13 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout0x2 = math_max_f32(vout0x2, voutput_min); vout0x3 = math_max_f32(vout0x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); vout0x2 = math_min_f32(vout0x2, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c index 73bf59de5a0..d42717834c3 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( @@ -187,7 +186,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( w = (const float*) w + 8;
- const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout0x2 = math_max_f32(vout0x2, voutput_min); @@ -197,7 +196,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( vout0x6 = math_max_f32(vout0x6, voutput_min); vout0x7 = math_max_f32(vout0x7, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); vout0x2 = math_min_f32(vout0x2, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c index 660ea399c00..60f32047731 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c @@ -12,7 +12,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar( @@ -127,13 +126,13 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar( w = (const float*) w + 2; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout1x1 = math_max_f32(vout1x1, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff
--git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c index 580148c13b1..a75b95cc64f 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( @@ -174,7 +173,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); @@ -184,7 +183,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( vout0x3 = math_max_f32(vout0x3, voutput_min); vout1x3 = math_max_f32(vout1x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c index c5edd90fbb5..523d285b7e2 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( @@ -270,7 +269,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( w = (const float*) w + 8; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min =
xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); @@ -288,7 +287,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( vout0x7 = math_max_f32(vout0x7, voutput_min); vout1x7 = math_max_f32(vout1x7, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c index 33b41fa79e5..e8855d10d4a 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( @@ -268,7 +267,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout2x0 = math_max_f32(vout2x0, voutput_min); @@ -286,7 +285,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( vout2x3 = math_max_f32(vout2x3, voutput_min); vout3x3 = math_max_f32(vout3x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout2x0 = math_min_f32(vout2x0, voutput_max); diff --git a/src/qs8-gemm/scalar.c.in
b/src/qs8-gemm/scalar.c.in index fbe2c29b299..97285dd32cb 100644 --- a/src/qs8-gemm/scalar.c.in +++ b/src/qs8-gemm/scalar.c.in @@ -14,8 +14,6 @@ $if VARIANT == "LRINTF": #include "xnnpack/math.h" $if NR % 4 != 0: #include "xnnpack/unaligned.h" -$if DATATYPE in ["QB4_F16"]: - #include $# $INDENT = 0 @@ -279,7 +277,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}_ w = (const float*) w + ${NR * 2}; $if DATATYPE in ["QB4_F16"]: - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); $else: const float voutput_min = params->scalar.min; $for N in range(NR): @@ -287,7 +285,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}_ vout${M}x${N} = ${MAX_F32}(vout${M}x${N}, voutput_min); $if DATATYPE in ["QB4_F16"]: - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) &params->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); $else: const float voutput_max = params->scalar.max; $for N in range(NR): diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c new file mode 100644 index 00000000000..0b7b42c16b3 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c @@ -0,0 +1,2319 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree.
+ + +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = params ? (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + ((int32_t*) out)[0] = b[0]; + ((int32_t*) out)[1] = b[1]; + ((int32_t*) out)[2] = b[2]; + ((int32_t*) out)[3] = b[3]; + ((int32_t*) out)[4] = b[4]; + ((int32_t*) out)[5] = b[5]; + ((int32_t*) out)[6] = b[6]; + ((int32_t*) out)[7] = b[7]; + ((int32_t*) out)[8] = b[8]; + ((int32_t*) out)[9] = b[9]; + ((int32_t*) out)[10] = b[10]; + ((int32_t*) out)[11] = b[11]; + ((int32_t*) out)[12] = b[12]; + ((int32_t*) out)[13] = b[13]; + ((int32_t*) out)[14] = b[14]; + ((int32_t*) out)[15] = b[15]; + b += 16; + } else { + ((int32_t*) out)[0] = 0; + ((int32_t*) out)[1] = 0; + ((int32_t*) out)[2] = 0; + ((int32_t*) out)[3] = 0; + ((int32_t*) out)[4] = 0; + ((int32_t*) out)[5] = 0; + ((int32_t*) out)[6] = 0; + ((int32_t*) out)[7] = 0; + ((int32_t*) out)[8] = 0; + ((int32_t*) out)[9] = 0; + ((int32_t*) out)[10] = 0; + ((int32_t*) out)[11] = 0; + ((int32_t*) out)[12] = 0; + ((int32_t*) out)[13] = 0; + ((int32_t*) out)[14] = 0; + ((int32_t*) out)[15] = 0; + } + out += 16 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const 
int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + uint32_t ksum8 = 0; + uint32_t ksum9 = 0; + uint32_t ksum10 = 0; + uint32_t ksum11 = 0; + uint32_t ksum12 = 0; + uint32_t ksum13 = 0; + uint32_t ksum14 = 0; + uint32_t ksum15 = 0; + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = 
v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + 
const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + const int8_t v8x0 = w8[0]; + const int8_t v8x1 = w8[1]; + const int8_t v8x2 = w8[2]; + const int8_t v8x3 = w8[3]; + const int8_t v8x4 = w8[4]; + const int8_t v8x5 = w8[5]; + const int8_t v8x6 = w8[6]; + const int8_t v8x7 = w8[7]; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) 
v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + out[64] = v8x0; + out[65] = v8x1; + out[66] = v8x2; + out[67] = v8x3; + out[68] = v8x4; + out[69] = v8x5; + out[70] = v8x6; + out[71] = v8x7; + w8 += 8; + const int8_t v9x0 = w9[0]; + const int8_t v9x1 = w9[1]; + const int8_t v9x2 = w9[2]; + const int8_t v9x3 = w9[3]; + const int8_t v9x4 = w9[4]; + const int8_t v9x5 = w9[5]; + const int8_t v9x6 = w9[6]; + const int8_t v9x7 = w9[7]; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + out[72] = v9x0; + out[73] = v9x1; + out[74] = v9x2; + out[75] = v9x3; + out[76] = v9x4; + out[77] = v9x5; + out[78] = v9x6; + out[79] = v9x7; + w9 += 8; + const int8_t v10x0 = w10[0]; + const int8_t v10x1 = w10[1]; + const int8_t v10x2 = w10[2]; + const int8_t v10x3 = w10[3]; + const int8_t v10x4 = w10[4]; + const int8_t v10x5 = w10[5]; + const int8_t v10x6 = w10[6]; + const int8_t v10x7 = w10[7]; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + out[80] = v10x0; + out[81] = v10x1; + out[82] = v10x2; + out[83] = v10x3; + out[84] = v10x4; + out[85] = v10x5; + out[86] = v10x6; + out[87] = v10x7; + w10 += 8; + const int8_t v11x0 = w11[0]; + const int8_t v11x1 = w11[1]; + const int8_t v11x2 = w11[2]; + const int8_t v11x3 = w11[3]; + const int8_t v11x4 = w11[4]; + const int8_t v11x5 = w11[5]; + const int8_t v11x6 = w11[6]; + const int8_t v11x7 = w11[7]; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + 
ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + out[88] = v11x0; + out[89] = v11x1; + out[90] = v11x2; + out[91] = v11x3; + out[92] = v11x4; + out[93] = v11x5; + out[94] = v11x6; + out[95] = v11x7; + w11 += 8; + const int8_t v12x0 = w12[0]; + const int8_t v12x1 = w12[1]; + const int8_t v12x2 = w12[2]; + const int8_t v12x3 = w12[3]; + const int8_t v12x4 = w12[4]; + const int8_t v12x5 = w12[5]; + const int8_t v12x6 = w12[6]; + const int8_t v12x7 = w12[7]; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + out[96] = v12x0; + out[97] = v12x1; + out[98] = v12x2; + out[99] = v12x3; + out[100] = v12x4; + out[101] = v12x5; + out[102] = v12x6; + out[103] = v12x7; + w12 += 8; + const int8_t v13x0 = w13[0]; + const int8_t v13x1 = w13[1]; + const int8_t v13x2 = w13[2]; + const int8_t v13x3 = w13[3]; + const int8_t v13x4 = w13[4]; + const int8_t v13x5 = w13[5]; + const int8_t v13x6 = w13[6]; + const int8_t v13x7 = w13[7]; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + out[104] = v13x0; + out[105] = v13x1; + out[106] = v13x2; + out[107] = v13x3; + out[108] = v13x4; + out[109] = v13x5; + out[110] = v13x6; + out[111] = v13x7; + w13 += 8; + const int8_t v14x0 = w14[0]; + const int8_t v14x1 = w14[1]; + const int8_t v14x2 = w14[2]; + const int8_t v14x3 = w14[3]; + const int8_t v14x4 = w14[4]; + const int8_t v14x5 = w14[5]; + const int8_t v14x6 = w14[6]; + const int8_t v14x7 = w14[7]; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + 
ksum14 += (uint32_t) v14x7; + out[112] = v14x0; + out[113] = v14x1; + out[114] = v14x2; + out[115] = v14x3; + out[116] = v14x4; + out[117] = v14x5; + out[118] = v14x6; + out[119] = v14x7; + w14 += 8; + const int8_t v15x0 = w15[0]; + const int8_t v15x1 = w15[1]; + const int8_t v15x2 = w15[2]; + const int8_t v15x3 = w15[3]; + const int8_t v15x4 = w15[4]; + const int8_t v15x5 = w15[5]; + const int8_t v15x6 = w15[6]; + const int8_t v15x7 = w15[7]; + ksum15 += (uint32_t) v15x0; + ksum15 += (uint32_t) v15x1; + ksum15 += (uint32_t) v15x2; + ksum15 += (uint32_t) v15x3; + ksum15 += (uint32_t) v15x4; + ksum15 += (uint32_t) v15x5; + ksum15 += (uint32_t) v15x6; + ksum15 += (uint32_t) v15x7; + out[120] = v15x0; + out[121] = v15x1; + out[122] = v15x2; + out[123] = v15x3; + out[124] = v15x4; + out[125] = v15x5; + out[126] = v15x6; + out[127] = v15x7; + w15 += 8; + out += 128; + } + + // KC remainder 1..KR-1 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? 
w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? 
w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? 
w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? 
w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + const int8_t v8x0 = 0 < k ? w8[0] : izp; + const int8_t v8x1 = 1 < k ? w8[1] : izp; + const int8_t v8x2 = 2 < k ? w8[2] : izp; + const int8_t v8x3 = 3 < k ? w8[3] : izp; + const int8_t v8x4 = 4 < k ? w8[4] : izp; + const int8_t v8x5 = 5 < k ? w8[5] : izp; + const int8_t v8x6 = 6 < k ? w8[6] : izp; + const int8_t v8x7 = 7 < k ? w8[7] : izp; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + if (0 < k) { + out[64] = v8x0; + } + if (1 < k) { + out[65] = v8x1; + } + if (2 < k) { + out[66] = v8x2; + } + if (3 < k) { + out[67] = v8x3; + } + if (4 < k) { + out[68] = v8x4; + } + if (5 < k) { + out[69] = v8x5; + } + if (6 < k) { + out[70] = v8x6; + } + if (7 < k) { + out[71] = v8x7; + } + w8 += 8; + const int8_t v9x0 = 0 < k ? w9[0] : izp; + const int8_t v9x1 = 1 < k ? w9[1] : izp; + const int8_t v9x2 = 2 < k ? w9[2] : izp; + const int8_t v9x3 = 3 < k ? w9[3] : izp; + const int8_t v9x4 = 4 < k ? w9[4] : izp; + const int8_t v9x5 = 5 < k ? w9[5] : izp; + const int8_t v9x6 = 6 < k ? w9[6] : izp; + const int8_t v9x7 = 7 < k ? 
w9[7] : izp; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + if (0 < k) { + out[72] = v9x0; + } + if (1 < k) { + out[73] = v9x1; + } + if (2 < k) { + out[74] = v9x2; + } + if (3 < k) { + out[75] = v9x3; + } + if (4 < k) { + out[76] = v9x4; + } + if (5 < k) { + out[77] = v9x5; + } + if (6 < k) { + out[78] = v9x6; + } + if (7 < k) { + out[79] = v9x7; + } + w9 += 8; + const int8_t v10x0 = 0 < k ? w10[0] : izp; + const int8_t v10x1 = 1 < k ? w10[1] : izp; + const int8_t v10x2 = 2 < k ? w10[2] : izp; + const int8_t v10x3 = 3 < k ? w10[3] : izp; + const int8_t v10x4 = 4 < k ? w10[4] : izp; + const int8_t v10x5 = 5 < k ? w10[5] : izp; + const int8_t v10x6 = 6 < k ? w10[6] : izp; + const int8_t v10x7 = 7 < k ? w10[7] : izp; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + if (0 < k) { + out[80] = v10x0; + } + if (1 < k) { + out[81] = v10x1; + } + if (2 < k) { + out[82] = v10x2; + } + if (3 < k) { + out[83] = v10x3; + } + if (4 < k) { + out[84] = v10x4; + } + if (5 < k) { + out[85] = v10x5; + } + if (6 < k) { + out[86] = v10x6; + } + if (7 < k) { + out[87] = v10x7; + } + w10 += 8; + const int8_t v11x0 = 0 < k ? w11[0] : izp; + const int8_t v11x1 = 1 < k ? w11[1] : izp; + const int8_t v11x2 = 2 < k ? w11[2] : izp; + const int8_t v11x3 = 3 < k ? w11[3] : izp; + const int8_t v11x4 = 4 < k ? w11[4] : izp; + const int8_t v11x5 = 5 < k ? w11[5] : izp; + const int8_t v11x6 = 6 < k ? w11[6] : izp; + const int8_t v11x7 = 7 < k ? 
w11[7] : izp; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + if (0 < k) { + out[88] = v11x0; + } + if (1 < k) { + out[89] = v11x1; + } + if (2 < k) { + out[90] = v11x2; + } + if (3 < k) { + out[91] = v11x3; + } + if (4 < k) { + out[92] = v11x4; + } + if (5 < k) { + out[93] = v11x5; + } + if (6 < k) { + out[94] = v11x6; + } + if (7 < k) { + out[95] = v11x7; + } + w11 += 8; + const int8_t v12x0 = 0 < k ? w12[0] : izp; + const int8_t v12x1 = 1 < k ? w12[1] : izp; + const int8_t v12x2 = 2 < k ? w12[2] : izp; + const int8_t v12x3 = 3 < k ? w12[3] : izp; + const int8_t v12x4 = 4 < k ? w12[4] : izp; + const int8_t v12x5 = 5 < k ? w12[5] : izp; + const int8_t v12x6 = 6 < k ? w12[6] : izp; + const int8_t v12x7 = 7 < k ? w12[7] : izp; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + if (0 < k) { + out[96] = v12x0; + } + if (1 < k) { + out[97] = v12x1; + } + if (2 < k) { + out[98] = v12x2; + } + if (3 < k) { + out[99] = v12x3; + } + if (4 < k) { + out[100] = v12x4; + } + if (5 < k) { + out[101] = v12x5; + } + if (6 < k) { + out[102] = v12x6; + } + if (7 < k) { + out[103] = v12x7; + } + w12 += 8; + const int8_t v13x0 = 0 < k ? w13[0] : izp; + const int8_t v13x1 = 1 < k ? w13[1] : izp; + const int8_t v13x2 = 2 < k ? w13[2] : izp; + const int8_t v13x3 = 3 < k ? w13[3] : izp; + const int8_t v13x4 = 4 < k ? w13[4] : izp; + const int8_t v13x5 = 5 < k ? w13[5] : izp; + const int8_t v13x6 = 6 < k ? w13[6] : izp; + const int8_t v13x7 = 7 < k ? 
w13[7] : izp; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + if (0 < k) { + out[104] = v13x0; + } + if (1 < k) { + out[105] = v13x1; + } + if (2 < k) { + out[106] = v13x2; + } + if (3 < k) { + out[107] = v13x3; + } + if (4 < k) { + out[108] = v13x4; + } + if (5 < k) { + out[109] = v13x5; + } + if (6 < k) { + out[110] = v13x6; + } + if (7 < k) { + out[111] = v13x7; + } + w13 += 8; + const int8_t v14x0 = 0 < k ? w14[0] : izp; + const int8_t v14x1 = 1 < k ? w14[1] : izp; + const int8_t v14x2 = 2 < k ? w14[2] : izp; + const int8_t v14x3 = 3 < k ? w14[3] : izp; + const int8_t v14x4 = 4 < k ? w14[4] : izp; + const int8_t v14x5 = 5 < k ? w14[5] : izp; + const int8_t v14x6 = 6 < k ? w14[6] : izp; + const int8_t v14x7 = 7 < k ? w14[7] : izp; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + if (0 < k) { + out[112] = v14x0; + } + if (1 < k) { + out[113] = v14x1; + } + if (2 < k) { + out[114] = v14x2; + } + if (3 < k) { + out[115] = v14x3; + } + if (4 < k) { + out[116] = v14x4; + } + if (5 < k) { + out[117] = v14x5; + } + if (6 < k) { + out[118] = v14x6; + } + if (7 < k) { + out[119] = v14x7; + } + w14 += 8; + const int8_t v15x0 = 0 < k ? w15[0] : izp; + const int8_t v15x1 = 1 < k ? w15[1] : izp; + const int8_t v15x2 = 2 < k ? w15[2] : izp; + const int8_t v15x3 = 3 < k ? w15[3] : izp; + const int8_t v15x4 = 4 < k ? w15[4] : izp; + const int8_t v15x5 = 5 < k ? w15[5] : izp; + const int8_t v15x6 = 6 < k ? w15[6] : izp; + const int8_t v15x7 = 7 < k ? 
w15[7] : izp; + ksum15 += (uint32_t) v15x0; + ksum15 += (uint32_t) v15x1; + ksum15 += (uint32_t) v15x2; + ksum15 += (uint32_t) v15x3; + ksum15 += (uint32_t) v15x4; + ksum15 += (uint32_t) v15x5; + ksum15 += (uint32_t) v15x6; + ksum15 += (uint32_t) v15x7; + if (0 < k) { + out[120] = v15x0; + } + if (1 < k) { + out[121] = v15x1; + } + if (2 < k) { + out[122] = v15x2; + } + if (3 < k) { + out[123] = v15x3; + } + if (4 < k) { + out[124] = v15x4; + } + if (5 < k) { + out[125] = v15x5; + } + if (6 < k) { + out[126] = v15x6; + } + if (7 < k) { + out[127] = v15x7; + } + w15 += 8; + out += 128; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + packed_b[8] -= ksum8 * izp; + packed_b[9] -= ksum9 * izp; + packed_b[10] -= ksum10 * izp; + packed_b[11] -= ksum11 * izp; + packed_b[12] -= ksum12 * izp; + packed_b[13] -= ksum13 * izp; + packed_b[14] -= ksum14 * izp; + packed_b[15] -= ksum15 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(int32_t); + + // NR remainder has less than 16 rows so last row is not loaded + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const 
int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + uint32_t ksum8 = 0; + uint32_t ksum9 = 0; + uint32_t ksum10 = 0; + uint32_t ksum11 = 0; + uint32_t ksum12 = 0; + uint32_t ksum13 = 0; + uint32_t ksum14 = 0; + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + 
ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + 
ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = 
v7x7; + w7 += 8; + const int8_t v8x0 = w8[0]; + const int8_t v8x1 = w8[1]; + const int8_t v8x2 = w8[2]; + const int8_t v8x3 = w8[3]; + const int8_t v8x4 = w8[4]; + const int8_t v8x5 = w8[5]; + const int8_t v8x6 = w8[6]; + const int8_t v8x7 = w8[7]; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + out[64] = v8x0; + out[65] = v8x1; + out[66] = v8x2; + out[67] = v8x3; + out[68] = v8x4; + out[69] = v8x5; + out[70] = v8x6; + out[71] = v8x7; + w8 += 8; + const int8_t v9x0 = w9[0]; + const int8_t v9x1 = w9[1]; + const int8_t v9x2 = w9[2]; + const int8_t v9x3 = w9[3]; + const int8_t v9x4 = w9[4]; + const int8_t v9x5 = w9[5]; + const int8_t v9x6 = w9[6]; + const int8_t v9x7 = w9[7]; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + out[72] = v9x0; + out[73] = v9x1; + out[74] = v9x2; + out[75] = v9x3; + out[76] = v9x4; + out[77] = v9x5; + out[78] = v9x6; + out[79] = v9x7; + w9 += 8; + const int8_t v10x0 = w10[0]; + const int8_t v10x1 = w10[1]; + const int8_t v10x2 = w10[2]; + const int8_t v10x3 = w10[3]; + const int8_t v10x4 = w10[4]; + const int8_t v10x5 = w10[5]; + const int8_t v10x6 = w10[6]; + const int8_t v10x7 = w10[7]; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + out[80] = v10x0; + out[81] = v10x1; + out[82] = v10x2; + out[83] = v10x3; + out[84] = v10x4; + out[85] = v10x5; + out[86] = v10x6; + out[87] = v10x7; + w10 += 8; + const int8_t v11x0 = w11[0]; + const int8_t v11x1 = w11[1]; + const int8_t v11x2 = w11[2]; + const int8_t 
v11x3 = w11[3]; + const int8_t v11x4 = w11[4]; + const int8_t v11x5 = w11[5]; + const int8_t v11x6 = w11[6]; + const int8_t v11x7 = w11[7]; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + out[88] = v11x0; + out[89] = v11x1; + out[90] = v11x2; + out[91] = v11x3; + out[92] = v11x4; + out[93] = v11x5; + out[94] = v11x6; + out[95] = v11x7; + w11 += 8; + const int8_t v12x0 = w12[0]; + const int8_t v12x1 = w12[1]; + const int8_t v12x2 = w12[2]; + const int8_t v12x3 = w12[3]; + const int8_t v12x4 = w12[4]; + const int8_t v12x5 = w12[5]; + const int8_t v12x6 = w12[6]; + const int8_t v12x7 = w12[7]; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + out[96] = v12x0; + out[97] = v12x1; + out[98] = v12x2; + out[99] = v12x3; + out[100] = v12x4; + out[101] = v12x5; + out[102] = v12x6; + out[103] = v12x7; + w12 += 8; + const int8_t v13x0 = w13[0]; + const int8_t v13x1 = w13[1]; + const int8_t v13x2 = w13[2]; + const int8_t v13x3 = w13[3]; + const int8_t v13x4 = w13[4]; + const int8_t v13x5 = w13[5]; + const int8_t v13x6 = w13[6]; + const int8_t v13x7 = w13[7]; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + out[104] = v13x0; + out[105] = v13x1; + out[106] = v13x2; + out[107] = v13x3; + out[108] = v13x4; + out[109] = v13x5; + out[110] = v13x6; + out[111] = v13x7; + w13 += 8; + const int8_t v14x0 = w14[0]; + const int8_t v14x1 = w14[1]; + const int8_t v14x2 = w14[2]; + const int8_t v14x3 = w14[3]; + const int8_t 
v14x4 = w14[4]; + const int8_t v14x5 = w14[5]; + const int8_t v14x6 = w14[6]; + const int8_t v14x7 = w14[7]; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + out[112] = v14x0; + out[113] = v14x1; + out[114] = v14x2; + out[115] = v14x3; + out[116] = v14x4; + out[117] = v14x5; + out[118] = v14x6; + out[119] = v14x7; + w14 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? 
w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? 
w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? 
w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? 
w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + const int8_t v8x0 = 0 < k ? w8[0] : izp; + const int8_t v8x1 = 1 < k ? w8[1] : izp; + const int8_t v8x2 = 2 < k ? w8[2] : izp; + const int8_t v8x3 = 3 < k ? w8[3] : izp; + const int8_t v8x4 = 4 < k ? w8[4] : izp; + const int8_t v8x5 = 5 < k ? w8[5] : izp; + const int8_t v8x6 = 6 < k ? w8[6] : izp; + const int8_t v8x7 = 7 < k ? w8[7] : izp; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + if (0 < k) { + out[64] = v8x0; + } + if (1 < k) { + out[65] = v8x1; + } + if (2 < k) { + out[66] = v8x2; + } + if (3 < k) { + out[67] = v8x3; + } + if (4 < k) { + out[68] = v8x4; + } + if (5 < k) { + out[69] = v8x5; + } + if (6 < k) { + out[70] = v8x6; + } + if (7 < k) { + out[71] = v8x7; + } + w8 += 8; + const int8_t v9x0 = 0 < k ? w9[0] : izp; + const int8_t v9x1 = 1 < k ? w9[1] : izp; + const int8_t v9x2 = 2 < k ? w9[2] : izp; + const int8_t v9x3 = 3 < k ? w9[3] : izp; + const int8_t v9x4 = 4 < k ? w9[4] : izp; + const int8_t v9x5 = 5 < k ? w9[5] : izp; + const int8_t v9x6 = 6 < k ? w9[6] : izp; + const int8_t v9x7 = 7 < k ? 
w9[7] : izp; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + if (0 < k) { + out[72] = v9x0; + } + if (1 < k) { + out[73] = v9x1; + } + if (2 < k) { + out[74] = v9x2; + } + if (3 < k) { + out[75] = v9x3; + } + if (4 < k) { + out[76] = v9x4; + } + if (5 < k) { + out[77] = v9x5; + } + if (6 < k) { + out[78] = v9x6; + } + if (7 < k) { + out[79] = v9x7; + } + w9 += 8; + const int8_t v10x0 = 0 < k ? w10[0] : izp; + const int8_t v10x1 = 1 < k ? w10[1] : izp; + const int8_t v10x2 = 2 < k ? w10[2] : izp; + const int8_t v10x3 = 3 < k ? w10[3] : izp; + const int8_t v10x4 = 4 < k ? w10[4] : izp; + const int8_t v10x5 = 5 < k ? w10[5] : izp; + const int8_t v10x6 = 6 < k ? w10[6] : izp; + const int8_t v10x7 = 7 < k ? w10[7] : izp; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + if (0 < k) { + out[80] = v10x0; + } + if (1 < k) { + out[81] = v10x1; + } + if (2 < k) { + out[82] = v10x2; + } + if (3 < k) { + out[83] = v10x3; + } + if (4 < k) { + out[84] = v10x4; + } + if (5 < k) { + out[85] = v10x5; + } + if (6 < k) { + out[86] = v10x6; + } + if (7 < k) { + out[87] = v10x7; + } + w10 += 8; + const int8_t v11x0 = 0 < k ? w11[0] : izp; + const int8_t v11x1 = 1 < k ? w11[1] : izp; + const int8_t v11x2 = 2 < k ? w11[2] : izp; + const int8_t v11x3 = 3 < k ? w11[3] : izp; + const int8_t v11x4 = 4 < k ? w11[4] : izp; + const int8_t v11x5 = 5 < k ? w11[5] : izp; + const int8_t v11x6 = 6 < k ? w11[6] : izp; + const int8_t v11x7 = 7 < k ? 
w11[7] : izp; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + if (0 < k) { + out[88] = v11x0; + } + if (1 < k) { + out[89] = v11x1; + } + if (2 < k) { + out[90] = v11x2; + } + if (3 < k) { + out[91] = v11x3; + } + if (4 < k) { + out[92] = v11x4; + } + if (5 < k) { + out[93] = v11x5; + } + if (6 < k) { + out[94] = v11x6; + } + if (7 < k) { + out[95] = v11x7; + } + w11 += 8; + const int8_t v12x0 = 0 < k ? w12[0] : izp; + const int8_t v12x1 = 1 < k ? w12[1] : izp; + const int8_t v12x2 = 2 < k ? w12[2] : izp; + const int8_t v12x3 = 3 < k ? w12[3] : izp; + const int8_t v12x4 = 4 < k ? w12[4] : izp; + const int8_t v12x5 = 5 < k ? w12[5] : izp; + const int8_t v12x6 = 6 < k ? w12[6] : izp; + const int8_t v12x7 = 7 < k ? w12[7] : izp; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + if (0 < k) { + out[96] = v12x0; + } + if (1 < k) { + out[97] = v12x1; + } + if (2 < k) { + out[98] = v12x2; + } + if (3 < k) { + out[99] = v12x3; + } + if (4 < k) { + out[100] = v12x4; + } + if (5 < k) { + out[101] = v12x5; + } + if (6 < k) { + out[102] = v12x6; + } + if (7 < k) { + out[103] = v12x7; + } + w12 += 8; + const int8_t v13x0 = 0 < k ? w13[0] : izp; + const int8_t v13x1 = 1 < k ? w13[1] : izp; + const int8_t v13x2 = 2 < k ? w13[2] : izp; + const int8_t v13x3 = 3 < k ? w13[3] : izp; + const int8_t v13x4 = 4 < k ? w13[4] : izp; + const int8_t v13x5 = 5 < k ? w13[5] : izp; + const int8_t v13x6 = 6 < k ? w13[6] : izp; + const int8_t v13x7 = 7 < k ? 
w13[7] : izp; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + if (0 < k) { + out[104] = v13x0; + } + if (1 < k) { + out[105] = v13x1; + } + if (2 < k) { + out[106] = v13x2; + } + if (3 < k) { + out[107] = v13x3; + } + if (4 < k) { + out[108] = v13x4; + } + if (5 < k) { + out[109] = v13x5; + } + if (6 < k) { + out[110] = v13x6; + } + if (7 < k) { + out[111] = v13x7; + } + w13 += 8; + const int8_t v14x0 = 0 < k ? w14[0] : izp; + const int8_t v14x1 = 1 < k ? w14[1] : izp; + const int8_t v14x2 = 2 < k ? w14[2] : izp; + const int8_t v14x3 = 3 < k ? w14[3] : izp; + const int8_t v14x4 = 4 < k ? w14[4] : izp; + const int8_t v14x5 = 5 < k ? w14[5] : izp; + const int8_t v14x6 = 6 < k ? w14[6] : izp; + const int8_t v14x7 = 7 < k ? w14[7] : izp; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + if (0 < k) { + out[112] = v14x0; + } + if (1 < k) { + out[113] = v14x1; + } + if (2 < k) { + out[114] = v14x2; + } + if (3 < k) { + out[115] = v14x3; + } + if (4 < k) { + out[116] = v14x4; + } + if (5 < k) { + out[117] = v14x5; + } + if (6 < k) { + out[118] = v14x6; + } + if (7 < k) { + out[119] = v14x7; + } + w14 += 8; + out += 128; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + packed_b[8] -= ksum8 * izp; + packed_b[9] -= ksum9 * izp; + packed_b[10] -= ksum10 * izp; + packed_b[11] -= ksum11 * izp; + packed_b[12] -= ksum12 * izp; + packed_b[13] -= ksum13 * izp; + packed_b[14] -= ksum14 * izp; + out = (int8_t*) 
((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c new file mode 100644 index 00000000000..45c63d2b268 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c @@ -0,0 +1,444 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnniint8.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if defined(__x86_64__) + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const int8_t izp = params ? 
((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = 
_mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const 
int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + 
v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const 
int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = 
_mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = 
_mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +#endif // defined(__x86_64__) +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c new file mode 100644 index 00000000000..d954e4de632 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c 
@@ -0,0 +1,1175 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = params ? (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + ((int32_t*) out)[0] = b[0]; + ((int32_t*) out)[1] = b[1]; + ((int32_t*) out)[2] = b[2]; + ((int32_t*) out)[3] = b[3]; + ((int32_t*) out)[4] = b[4]; + ((int32_t*) out)[5] = b[5]; + ((int32_t*) out)[6] = b[6]; + ((int32_t*) out)[7] = b[7]; + b += 8; + } else { + ((int32_t*) out)[0] = 0; + ((int32_t*) out)[1] = 0; + ((int32_t*) out)[2] = 0; + ((int32_t*) out)[3] = 0; + ((int32_t*) out)[4] = 0; + ((int32_t*) out)[5] = 0; + ((int32_t*) out)[6] = 0; + ((int32_t*) out)[7] = 0; + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + 
uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; 
+ w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const 
int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? 
w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? 
w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? 
w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? 
w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + out += 64; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do 
{ + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(int32_t); + + // NR remainder has less than 8 rows so last row is not loaded + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + 
out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const 
int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? 
w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? 
w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? 
w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? 
w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + out += 64; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index d0d55e50e31..570273a66cd 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -21,6 +21,14 @@ XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c4__scalar, 16, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x32c4__scalar, 32, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar, 64, 4, 1, 4, 1) +XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar, 8, 8, 1, 8, 1) +XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 8, 1) + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if XNN_ENABLE_AVXVNNIINT8 && XNN_ARCH_X86_64 +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnniint8, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8, 8, 8, 1, 8, 1) +#endif + #ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_QS8_UKERNEL_WITH_PARAMS diff --git a/src/subgraph.c b/src/subgraph.c index d227add1624..4bc18c1b15b 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -13,11 +13,11 @@ #include #include -#include #include 
"xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" +#include "xnnpack/fp16.h" #include "xnnpack/hardware-config.h" #include "xnnpack/log.h" #include "xnnpack/math.h" diff --git a/src/subgraph/static-constant-pad.c b/src/subgraph/static-constant-pad.c index 5bad5e90fb4..ab75daff4e0 100644 --- a/src/subgraph/static-constant-pad.c +++ b/src/subgraph/static-constant-pad.c @@ -9,9 +9,9 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/common.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" diff --git a/src/x8-packw/kr-avxvnniint8.c.in b/src/x8-packw/kr-avxvnniint8.c.in new file mode 100644 index 00000000000..fc6fb0a2f77 --- /dev/null +++ b/src/x8-packw/kr-avxvnniint8.c.in @@ -0,0 +1,401 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +$assert NR == 8 +$assert KR == 8 +$assert TYPE in ["int8_t"] + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +$BITS = {"int8_t": 8}[TYPE] +$BTYPE = {"int8_t": "uint32_t"}[TYPE] +$WTYPE = {"int8_t": "int8_t"}[TYPE] +void xnn_qs${BITS}_packw_gemm_goi_ukernel_x${NR}c${KR}__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const ${WTYPE}* weights, + $if BITS == 8: + const int32_t* bias, + $else: + const ${WTYPE}* bias, + const void* scale, + ${WTYPE}* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + +// TODO: immintrin.h only provides _mm256_insert_epi64 for __x86_64__ +#if defined(__x86_64__) + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + $if BITS == 8: + const int8_t izp = params ? 
((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += ${NR}; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + $if BTYPE == TYPE: + out += ${NR}; + $else: + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = 
_mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = 
_mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = 
_mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w${NR-1}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = *b++; + $else: + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } else { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = 0; + $else: + *((${BTYPE}*) out) = 0; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } + $if BTYPE == TYPE: + out += (${NR} - n); + $else: + out += (${NR} - n) * sizeof(${BTYPE}); + + $if NR > 2: + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + } + + $if BITS == 8: + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = 
_mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, 
*(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, 
*(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +#endif // defined(__x86_64__) +} diff --git a/src/xnnpack/fp16.h b/src/xnnpack/fp16.h new file mode 100644 index 00000000000..dc3b47aa8dc --- /dev/null +++ b/src/xnnpack/fp16.h @@ -0,0 +1,179 @@ +#ifndef THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ +#define THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ + +#include + +// This file is an excerpt from https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h, +// including only the minimal functionality we need in XNNPACK. This works around some issues +// that we haven't been able to fix upstream (https://github.com/Maratyszcza/FP16/pull/32). 
See also: +// - https://github.com/microsoft/onnxruntime/pull/22294/files +// - https://github.com/google/XNNPACK/issues/6989 +// We also don't need a lot of the functionality in the upstream library. + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32 = { w }; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32 = { f }; + return fp32.as_bits; +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. 
+ */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-30 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 13-22 and bits 23-27 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias.
+ * The floating-point multiplication hardware would ensure that Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has non-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa as the half-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructed single-precision number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision number.
When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if it is smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables.
+ */ +static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + const uint32_t w = fp32_to_bits(f); + const float abs_f = fp32_from_bits(w & UINT32_C(0x7FFFFFFF)); + float base = (abs_f * scale_to_inf) * scale_to_zero; + + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + + +#endif // THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ diff --git a/src/xnnpack/math.h b/src/xnnpack/math.h index dab6e037853..893e038a312 100644 --- a/src/xnnpack/math.h +++ b/src/xnnpack/math.h @@ -18,8 +18,8 @@ #include // For _rotl. #endif -#include #include "xnnpack/common.h" +#include "xnnpack/fp16.h" // stdlib.h from Windows 10 SDK defines min & max macros. // Undefine them before defining the corresponding functions. 
diff --git a/src/xnnpack/quantization.h b/src/xnnpack/quantization.h index 95f832bada5..0884ba41a2b 100644 --- a/src/xnnpack/quantization.h +++ b/src/xnnpack/quantization.h @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack/math.h" #include "xnnpack/microparams.h" diff --git a/src/xnnpack/simd/f16-scalar.h b/src/xnnpack/simd/f16-scalar.h index c24076b282b..e95425c5b9b 100644 --- a/src/xnnpack/simd/f16-scalar.h +++ b/src/xnnpack/simd/f16-scalar.h @@ -11,8 +11,8 @@ #include #include -#include #include "xnnpack/common.h" +#include "xnnpack/fp16.h" // SIMD vector type for f16 using SCALAR. typedef uint16_t xnn_simd_f16_t; diff --git a/test/BUILD.bazel b/test/BUILD.bazel index b36d2b0c0f8..603357df6d6 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -26,11 +26,11 @@ load( MICROKERNEL_TEST_DEPS = [ ":next_prime", ":replicable_random_device", - "@FP16", "//:aligned_allocator", "//:all_microkernels", "//:allocator", "//:common", + "//:fp16", "//:isa_checks", "//:math", "//:memory", @@ -46,12 +46,12 @@ MICROKERNEL_TEST_DEPS = [ OPERATOR_TEST_DEPS = [ ":replicable_random_device", - "@FP16", "@pthreadpool", "//:aligned_allocator", "//:allocator", "//:cache", "//:common", + "//:fp16", "//:internal", "//:math", "//:microkernel_configs", @@ -189,8 +189,8 @@ sh_test( copts = xnnpack_simd_copts_for_arch(arch), deps = [ ":replicable_random_device", - "@FP16", "//:common", + "//:fp16", "//:isa_checks", "//:microkernels_h", ], @@ -1769,7 +1769,6 @@ xnnpack_cxx_library( deps = [ ":replicable_random_device", ":subgraph_unary_tester", - "@FP16", "//:XNNPACK", "//:math", "//:node_type", @@ -1885,7 +1884,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1915,7 +1913,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1932,7 +1929,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", 
"//:XNNPACK", "//:node_type", "//:operators", @@ -1953,7 +1949,6 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1974,7 +1969,6 @@ xnnpack_unit_test( shard_count = 5, deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operator_utils", @@ -1991,7 +1985,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2007,7 +2000,6 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2026,7 +2018,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2046,7 +2037,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2068,7 +2058,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2086,7 +2075,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2104,7 +2092,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2121,7 +2108,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2138,7 +2124,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operator_utils", @@ -2155,7 +2140,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2185,7 +2169,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2201,7 +2184,6 @@ 
xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2230,7 +2212,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2348,7 +2329,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:allocation_type", "//:allocator", diff --git a/test/abs.cc b/test/abs.cc index ad1e8401293..7aaabe56139 100644 --- a/test/abs.cc +++ b/test/abs.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/average-pooling-2d.cc b/test/average-pooling-2d.cc index 4602f2aded5..8e9d2b58c06 100644 --- a/test/average-pooling-2d.cc +++ b/test/average-pooling-2d.cc @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/avgpool-microkernel-tester.h b/test/avgpool-microkernel-tester.h index 35db8fe1bb4..68e4bd4fb39 100644 --- a/test/avgpool-microkernel-tester.h +++ b/test/avgpool-microkernel-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/bankers-rounding.cc b/test/bankers-rounding.cc index 0021bdad455..fe4ab88f153 100644 --- a/test/bankers-rounding.cc +++ b/test/bankers-rounding.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/batch-matrix-multiply.cc b/test/batch-matrix-multiply.cc index 55fbe9eeb3e..53bb93df183 100644 --- a/test/batch-matrix-multiply.cc +++ b/test/batch-matrix-multiply.cc @@ -20,7 +20,6 @@ #include // For std::vector. 
#include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/ceiling.cc b/test/ceiling.cc index 01168b81e60..16c1f9c4ce6 100644 --- a/test/ceiling.cc +++ b/test/ceiling.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/clamp.cc b/test/clamp.cc index f9c87f57769..e5db94311f0 100644 --- a/test/clamp.cc +++ b/test/clamp.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate2.cc b/test/concatenate2.cc index 06c6c351860..245f114fa3a 100644 --- a/test/concatenate2.cc +++ b/test/concatenate2.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate3.cc b/test/concatenate3.cc index 9a7888effff..cb4f56a4b33 100644 --- a/test/concatenate3.cc +++ b/test/concatenate3.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate4.cc b/test/concatenate4.cc index 0d6ac91a838..992ccf9668b 100644 --- a/test/concatenate4.cc +++ b/test/concatenate4.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate5.cc b/test/concatenate5.cc index 8dfea193ce7..2934812315c 100644 --- a/test/concatenate5.cc +++ b/test/concatenate5.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/conv-hwc2chw-microkernel-tester.h b/test/conv-hwc2chw-microkernel-tester.h index d354992bab4..eaca4b86116 100644 --- a/test/conv-hwc2chw-microkernel-tester.h +++ b/test/conv-hwc2chw-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include 
"xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/convert-operator-tester.h b/test/convert-operator-tester.h index 78724161ed6..9a9641a32d9 100644 --- a/test/convert-operator-tester.h +++ b/test/convert-operator-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" diff --git a/test/convert.cc b/test/convert.cc index 8b7933b1b99..5c418403e18 100644 --- a/test/convert.cc +++ b/test/convert.cc @@ -12,7 +12,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/convolution-2d.cc b/test/convolution-2d.cc index a029865d309..8e85f6cb916 100644 --- a/test/convolution-2d.cc +++ b/test/convolution-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h index ad155b26759..0a59b20b5b8 100644 --- a/test/convolution-operator-tester.h +++ b/test/convolution-operator-tester.h @@ -21,7 +21,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/cache.h" diff --git a/test/copy.cc b/test/copy.cc index 8683530503b..a19cb3d34ba 100644 --- a/test/copy.cc +++ b/test/copy.cc @@ -11,7 +11,6 @@ #include // For std::unique_ptr. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/deconvolution-2d.cc b/test/deconvolution-2d.cc index 1d7d53bc314..6e4f6cda19c 100644 --- a/test/deconvolution-2d.cc +++ b/test/deconvolution-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. 
#include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" diff --git a/test/deconvolution-operator-tester.h b/test/deconvolution-operator-tester.h index 21d108ed9a4..c74e75b6650 100644 --- a/test/deconvolution-operator-tester.h +++ b/test/deconvolution-operator-tester.h @@ -23,7 +23,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" diff --git a/test/depth-to-space-2d.cc b/test/depth-to-space-2d.cc index 924f2e18999..f20ee1de3a0 100644 --- a/test/depth-to-space-2d.cc +++ b/test/depth-to-space-2d.cc @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/depthwise-convolution-2d.cc b/test/depthwise-convolution-2d.cc index 2f008078792..923c0e67762 100644 --- a/test/depthwise-convolution-2d.cc +++ b/test/depthwise-convolution-2d.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/dwconv-microkernel-tester.cc b/test/dwconv-microkernel-tester.cc index d29e22e7572..942e5d284d2 100644 --- a/test/dwconv-microkernel-tester.cc +++ b/test/dwconv-microkernel-tester.cc @@ -21,7 +21,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/dwconv2d-microkernel-tester.h b/test/dwconv2d-microkernel-tester.h index 2af4605c7bc..15a8438306c 100644 --- a/test/dwconv2d-microkernel-tester.h +++ b/test/dwconv2d-microkernel-tester.h @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/elu.cc b/test/elu.cc index 1ad5b23fbd6..ccb49d61d68 100644 --- a/test/elu.cc +++ b/test/elu.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git 
a/test/even-split2.cc b/test/even-split2.cc index 221b5f00f09..ab00d5f777d 100644 --- a/test/even-split2.cc +++ b/test/even-split2.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split3.cc b/test/even-split3.cc index 99979aa0965..cd697a325ab 100644 --- a/test/even-split3.cc +++ b/test/even-split3.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split4.cc b/test/even-split4.cc index 13f872aaa0b..22d8ab52f82 100644 --- a/test/even-split4.cc +++ b/test/even-split4.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/f16-simd-scalar.cc b/test/f16-simd-scalar.cc index d5d100ede7a..3fa9cda6100 100644 --- a/test/f16-simd-scalar.cc +++ b/test/f16-simd-scalar.cc @@ -17,9 +17,9 @@ #include #include -#include #include "xnnpack/isa-checks.h" #include "xnnpack/simd/f16-scalar.h" +#include "xnnpack/fp16.h" #include "replicable_random_device.h" namespace xnnpack { diff --git a/test/f16-simd.cc.in b/test/f16-simd.cc.in index c6de215ce7f..84ebd2d78fb 100644 --- a/test/f16-simd.cc.in +++ b/test/f16-simd.cc.in @@ -19,9 +19,9 @@ $if ARCH_MACRO: #include #include -#include #include "xnnpack/isa-checks.h" #include "xnnpack/simd/f16-${ARCH}.h" +#include "xnnpack/fp16.h" #include "replicable_random_device.h" namespace xnnpack { diff --git a/test/floor.cc b/test/floor.cc index 0f62c175ffd..790b668d726 100644 --- a/test/floor.cc +++ b/test/floor.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h index a84bc876da7..c7e5695ca42 100644 --- a/test/fully-connected-operator-tester.h +++ 
b/test/fully-connected-operator-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" diff --git a/test/fully-connected.cc b/test/fully-connected.cc index 54624914ce0..a8a11097c10 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/gavgpool-cw-microkernel-tester.h b/test/gavgpool-cw-microkernel-tester.h index 077c7c0775b..fae2cb8e9d2 100644 --- a/test/gavgpool-cw-microkernel-tester.h +++ b/test/gavgpool-cw-microkernel-tester.h @@ -15,8 +15,8 @@ #include #include -#include #include "xnnpack.h" +#include "xnnpack/fp16.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" #include "replicable_random_device.h" diff --git a/test/gavgpool-microkernel-tester.h b/test/gavgpool-microkernel-tester.h index 78a6db9604c..723bf879e94 100644 --- a/test/gavgpool-microkernel-tester.h +++ b/test/gavgpool-microkernel-tester.h @@ -19,7 +19,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 2e5bf67566e..8d3d90b157d 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -14,8 +14,6 @@ #include #include -#include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/aligned-allocator.h" diff --git a/test/global-average-pooling-1d.cc b/test/global-average-pooling-1d.cc index 7cd48566976..c7a472835e4 100644 --- a/test/global-average-pooling-1d.cc +++ b/test/global-average-pooling-1d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-average-pooling-2d.cc b/test/global-average-pooling-2d.cc index 
9a51ac9aa16..7f83beb8fb3 100644 --- a/test/global-average-pooling-2d.cc +++ b/test/global-average-pooling-2d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-sum-pooling-1d.cc b/test/global-sum-pooling-1d.cc index fcf429581e3..45785d8fa37 100644 --- a/test/global-sum-pooling-1d.cc +++ b/test/global-sum-pooling-1d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-sum-pooling-2d.cc b/test/global-sum-pooling-2d.cc index cbe47848380..8d618494db5 100644 --- a/test/global-sum-pooling-2d.cc +++ b/test/global-sum-pooling-2d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/hardswish.cc b/test/hardswish.cc index 515e9e62faa..728947595f4 100644 --- a/test/hardswish.cc +++ b/test/hardswish.cc @@ -13,7 +13,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/ibilinear-microkernel-tester.h b/test/ibilinear-microkernel-tester.h index f9fdecdbf8f..c13e2a7f409 100644 --- a/test/ibilinear-microkernel-tester.h +++ b/test/ibilinear-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/math.h" diff --git a/test/leaky-relu.cc b/test/leaky-relu.cc index dfba1ccce39..0b9d2180442 100644 --- a/test/leaky-relu.cc +++ b/test/leaky-relu.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/max-pooling-2d.cc b/test/max-pooling-2d.cc index 241ec93ab46..2c89bb58efc 100644 --- a/test/max-pooling-2d.cc +++ b/test/max-pooling-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. 
#include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" diff --git a/test/maxpool-microkernel-tester.h b/test/maxpool-microkernel-tester.h index faedb72b52b..7587b033cd5 100644 --- a/test/maxpool-microkernel-tester.h +++ b/test/maxpool-microkernel-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/negate.cc b/test/negate.cc index 9ae0008b3f9..a36ef815844 100644 --- a/test/negate.cc +++ b/test/negate.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/packing.cc b/test/packing.cc index eeebc0946d6..4c6bd1910bd 100644 --- a/test/packing.cc +++ b/test/packing.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack/math.h" #include "xnnpack/microkernel-utils.h" #include "xnnpack/microparams-init.h" diff --git a/test/prelu-microkernel-tester.h b/test/prelu-microkernel-tester.h index 10ec6e4e1f5..b50bab36775 100644 --- a/test/prelu-microkernel-tester.h +++ b/test/prelu-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/prelu.cc b/test/prelu.cc index 5f4a65daf61..698eee99a34 100644 --- a/test/prelu.cc +++ b/test/prelu.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/raddstoreexpminusmax-microkernel-tester.h b/test/raddstoreexpminusmax-microkernel-tester.h index 793e161a8ed..4ea512f71a3 100644 --- a/test/raddstoreexpminusmax-microkernel-tester.h +++ b/test/raddstoreexpminusmax-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/rdsum-microkernel-tester.h 
b/test/rdsum-microkernel-tester.h index dc08fbf8700..0dc57804777 100644 --- a/test/rdsum-microkernel-tester.h +++ b/test/rdsum-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/reciprocal-square-root.cc b/test/reciprocal-square-root.cc index 8f7c6e1baae..b5734767615 100644 --- a/test/reciprocal-square-root.cc +++ b/test/reciprocal-square-root.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/rsum-microkernel-tester.h b/test/rsum-microkernel-tester.h index dee79bb9f0c..9f2c94b9419 100644 --- a/test/rsum-microkernel-tester.h +++ b/test/rsum-microkernel-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/scaled-dot-product-attention.cc b/test/scaled-dot-product-attention.cc index 0d6fd61e381..c8782608783 100644 --- a/test/scaled-dot-product-attention.cc +++ b/test/scaled-dot-product-attention.cc @@ -17,7 +17,6 @@ #include // For std::vector. 
#include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/sigmoid.cc b/test/sigmoid.cc index 82ffa4db623..df1935063ef 100644 --- a/test/sigmoid.cc +++ b/test/sigmoid.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/softmax.cc b/test/softmax.cc index b17134654a2..458638b77a9 100644 --- a/test/softmax.cc +++ b/test/softmax.cc @@ -12,7 +12,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/space-to-depth-2d.cc b/test/space-to-depth-2d.cc index 9fc996e3777..66869423ad6 100644 --- a/test/space-to-depth-2d.cc +++ b/test/space-to-depth-2d.cc @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/spmm-microkernel-tester.h b/test/spmm-microkernel-tester.h index 4ffa7cf6f89..bd8a115cd2f 100644 --- a/test/spmm-microkernel-tester.h +++ b/test/spmm-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/square-root.cc b/test/square-root.cc index 1ed486e0db3..23ef94f803a 100644 --- a/test/square-root.cc +++ b/test/square-root.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/square.cc b/test/square.cc index f183bae5603..4fbc26b979d 100644 --- a/test/square.cc +++ b/test/square.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-constant-pad.cc b/test/static-constant-pad.cc index c9ca1080777..77f0d813d7f 100644 --- a/test/static-constant-pad.cc +++ b/test/static-constant-pad.cc @@ -13,7 +13,6 @@ #include // For 
std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" diff --git a/test/static-expand-dims.cc b/test/static-expand-dims.cc index d608b5dddcd..5fb68440678 100644 --- a/test/static-expand-dims.cc +++ b/test/static-expand-dims.cc @@ -15,7 +15,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-mean.cc b/test/static-mean.cc index 92b7a9d309d..a55f63c4acf 100644 --- a/test/static-mean.cc +++ b/test/static-mean.cc @@ -16,7 +16,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/static-reshape.cc b/test/static-reshape.cc index 0745cbd2273..6107acbbc04 100644 --- a/test/static-reshape.cc +++ b/test/static-reshape.cc @@ -15,7 +15,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-resize-bilinear-2d.cc b/test/static-resize-bilinear-2d.cc index 250c048995b..6ea0dabddd2 100644 --- a/test/static-resize-bilinear-2d.cc +++ b/test/static-resize-bilinear-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. 
#include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-slice.cc b/test/static-slice.cc index 81c25ea0d0a..e56e239a2a4 100644 --- a/test/static-slice.cc +++ b/test/static-slice.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-transpose.cc b/test/static-transpose.cc index 52dc729c22e..d5a92cd7cbe 100644 --- a/test/static-transpose.cc +++ b/test/static-transpose.cc @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/subgraph-fp16.cc b/test/subgraph-fp16.cc index 7653908077b..ae25e6e6d86 100644 --- a/test/subgraph-fp16.cc +++ b/test/subgraph-fp16.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/node-type.h" diff --git a/test/tanh-operator-tester.h b/test/tanh-operator-tester.h index f43e5d7e997..755a6def164 100644 --- a/test/tanh-operator-tester.h +++ b/test/tanh-operator-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "replicable_random_device.h" diff --git a/test/tanh.cc b/test/tanh.cc index 2d01af0c534..bd0cf712cb3 100644 --- a/test/tanh.cc +++ b/test/tanh.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/unary-operator-tester.cc b/test/unary-operator-tester.cc index 74912831187..f6d6545eb45 100644 --- a/test/unary-operator-tester.cc +++ b/test/unary-operator-tester.cc @@ -19,7 +19,6 @@ #include #include -#include #include "xnnpack.h" #include "replicable_random_device.h" diff --git a/test/vbinary-microkernel-tester.cc b/test/vbinary-microkernel-tester.cc index 00515ad5713..f4a48916660 100644 --- a/test/vbinary-microkernel-tester.cc +++ b/test/vbinary-microkernel-tester.cc @@ -20,7 +20,6 @@ 
#include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams-init.h" diff --git a/test/vcmul-microkernel-tester.h b/test/vcmul-microkernel-tester.h index 4c7461580b9..4d875874c38 100644 --- a/test/vcmul-microkernel-tester.h +++ b/test/vcmul-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/isa-checks.h" #include "xnnpack/microfnptr.h" diff --git a/test/vcvt-microkernel-tester.cc b/test/vcvt-microkernel-tester.cc index e17d555edc5..44129ffac4c 100644 --- a/test/vcvt-microkernel-tester.cc +++ b/test/vcvt-microkernel-tester.cc @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" diff --git a/test/vmulcaddc-microkernel-tester.h b/test/vmulcaddc-microkernel-tester.h index c5f15d12aaf..13afb0cb58e 100644 --- a/test/vmulcaddc-microkernel-tester.h +++ b/test/vmulcaddc-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/vunary-microkernel-tester.cc b/test/vunary-microkernel-tester.cc index ffa6e29a5fe..07a7ad6e91b 100644 --- a/test/vunary-microkernel-tester.cc +++ b/test/vunary-microkernel-tester.cc @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/common.h" #include "xnnpack/microfnptr.h" diff --git a/test/vunary-microkernel-tester.h b/test/vunary-microkernel-tester.h index 8293d2769c9..a65b6646570 100644 --- a/test/vunary-microkernel-tester.h +++ b/test/vunary-microkernel-tester.h @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "replicable_random_device.h"