diff --git a/CMakeLists.txt b/CMakeLists.txt index 1877c8dec45c..0e234597c53b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,7 +44,12 @@ SET(CMAKE_CXX_EXTENSIONS NO) # ---[ Options. SET(XNNPACK_LIBRARY_TYPE "default" CACHE STRING "Type of library (shared, static, or default) to build") SET_PROPERTY(CACHE XNNPACK_LIBRARY_TYPE PROPERTY STRINGS default static shared) -OPTION(XNNPACK_ENABLE_ASSEMBLY "Build XNNPACK with assembly micro-kernels" ON) +IF(CMAKE_C_COMPILER_ID STREQUAL "MSVC") + # Disable assembly when using MSVC until support is added. + OPTION(XNNPACK_ENABLE_ASSEMBLY "Build XNNPACK with assembly micro-kernels" OFF) +ELSE() + OPTION(XNNPACK_ENABLE_ASSEMBLY "Build XNNPACK with assembly micro-kernels" ON) +ENDIF() OPTION(XNNPACK_ENABLE_MEMOPT "Build XNNPACK with optimized memory allocation scheme" ON) OPTION(XNNPACK_ENABLE_SPARSE "Build XNNPACK with graph rewriting for sparse inference" ON) OPTION(XNNPACK_ENABLE_GEMM_M_SPECIALIZATION "Build XNNPACK with support for selecting microkernel with different MR" ON) @@ -658,6 +663,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$") LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_F16C_MICROKERNEL_SRCS}) LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_FMA3_MICROKERNEL_SRCS}) LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX2_MICROKERNEL_SRCS}) + IF(XNNPACK_ENABLE_ASSEMBLY AND XNNPACK_TARGET_PROCESSOR MATCHES "x86_64") + LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AMD64_ASM_MICROKERNEL_SRCS}) + ENDIF() IF(XNNPACK_ENABLE_AVX512AMX) LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX512AMX_MICROKERNEL_SRCS}) ENDIF() @@ -705,6 +713,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$") LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_F16C_MICROKERNEL_SRCS}) LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_FMA3_MICROKERNEL_SRCS}) LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX2_MICROKERNEL_SRCS}) + IF(XNNPACK_ENABLE_ASSEMBLY AND XNNPACK_TARGET_PROCESSOR MATCHES "x86_64") + LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AMD64_ASM_MICROKERNEL_SRCS}) + ENDIF() IF(XNNPACK_ENABLE_AVX512AMX) LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX512AMX_MICROKERNEL_SRCS}) ENDIF() diff --git a/bench/f32-gemm-minmax.cc b/bench/f32-gemm-minmax.cc index 5ee34e2f4106..f10a4f901f5f 100644 --- a/bench/f32-gemm-minmax.cc +++ b/bench/f32-gemm-minmax.cc @@ -1286,6 +1286,306 @@ #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 +#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + static void f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + 
xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/9, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/10, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/11, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } 
+ + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast) + + static void 
f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/9, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/10, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/11, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast) +#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || 
XNN_ARCH_X86_64) static void f32_gemm_minmax_ukernel_1x16__avx512f_broadcast(benchmark::State& state, const char* net) { GEMMBenchmark(state, @@ -2187,6 +2487,116 @@ BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm) + static void f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + 
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane) + static void f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64(benchmark::State& state, const char* net) { GEMMBenchmark(state, xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, diff --git a/bench/qd8-f32-qc8w-gemm.cc b/bench/qd8-f32-qc8w-gemm.cc index 906d70576115..f6864d05cfee 100644 --- a/bench/qd8-f32-qc8w-gemm.cc +++ b/bench/qd8-f32-qc8w-gemm.cc @@ -724,6 +724,97 @@ #endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + 
benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane) +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY static void qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55(benchmark::State& state, const char* net) { GEMMBenchmark(state, @@ -1348,6 +1439,306 @@ #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + 
/*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/5, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/6, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/9, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/10, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/11, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + 
xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni, + 
xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/11, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/5, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni) +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_AVX512VNNI && (XNN_ARCH_X86 || XNN_ARCH_X86_64) static void 
qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni(benchmark::State& state, const char* net) { GEMMBenchmark(state, diff --git a/build_config/BUILD.bazel b/build_config/BUILD.bazel index 6352a43acdd4..c4052f8fa9c6 100644 --- a/build_config/BUILD.bazel +++ b/build_config/BUILD.bazel @@ -315,6 +315,21 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "x86_64", + match_any = [ + ":android_x86_64", + ":ios_x86_64", + ":linux_k8", + ":macos_x86_64", + ":macos_x86_64_legacy", + ":tvos_x86_64", + ":watchos_x86_64", + ":windows_x86_64_clang", + ":windows_x86_64_mingw", + ], +) + selects.config_setting_group( name = "riscv", match_any = [":linux_riscv64"], diff --git a/build_params.bzl b/build_params.bzl index d23edcd4163b..807a9fe0aafa 100644 --- a/build_params.bzl +++ b/build_params.bzl @@ -795,4 +795,21 @@ XNNPACK_PARAMS_FOR_ARCH = { ], extra_deps = [], # Extra deps for hexagon. ), + "amd64": _create_params( + cond = "//build_config:x86_64", + gcc_x86_copts = [ + "-mf16c", + "-mfma", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", + "-mgfni", + ], + msvc_x86_64_copts = ["/arch:AVX512"], + mingw_copts = ["-fno-asynchronous-unwind-tables"], + msys_copts = ["-fno-asynchronous-unwind-tables"], + ), } diff --git a/cmake/gen/aarch64_microkernels.cmake b/cmake/gen/aarch64_microkernels.cmake index 5ab1ee69c45b..caa50bb37336 100644 --- a/cmake/gen/aarch64_microkernels.cmake +++ b/cmake/gen/aarch64_microkernels.cmake @@ -123,6 +123,7 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS src/f32-dwconv/f32-dwconv-9p4c-minmax-asm-aarch64-neonfma.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2-prfm.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2.S + src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2-prfm.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4-prfm.S @@ -134,15 +135,24 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128.S src/f32-gemm/gen/f32-gemm-1x12-minmax-asm-aarch64-neonfma-cortex-a53.S + src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld64.S src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld128.S src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75.S src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-ld64.S + src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S src/f32-gemm/gen/f32-gemm-4x12-minmax-asm-aarch64-neonfma-cortex-a53.S + src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75.S + src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S 
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-cortex-a75.S src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S @@ -228,6 +238,14 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld128.S src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S diff --git a/cmake/gen/amd64_microkernels.cmake b/cmake/gen/amd64_microkernels.cmake new file mode 100644 index 000000000000..e11059c491aa --- /dev/null +++ b/cmake/gen/amd64_microkernels.cmake @@ -0,0 +1,70 @@ +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Description: microkernel filename lists for amd64 +# +# Auto-generated file. Do not edit! 
+# Generator: tools/update-microkernels.py + + +SET(PROD_AMD64_ASM_MICROKERNEL_SRCS) + +SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS + src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S + 
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S) + +SET(ALL_AMD64_ASM_MICROKERNEL_SRCS ${PROD_AMD64_ASM_MICROKERNEL_SRCS} + ${NON_PROD_AMD64_ASM_MICROKERNEL_SRCS}) diff --git a/cmake/gen/microkernels.cmake b/cmake/gen/microkernels.cmake index 6c42bc363421..bd8dcb7b7326 100644 --- a/cmake/gen/microkernels.cmake +++ b/cmake/gen/microkernels.cmake @@ -10,6 +10,7 @@ INCLUDE(cmake/gen/aarch32_microkernels.cmake) INCLUDE(cmake/gen/aarch64_microkernels.cmake) +INCLUDE(cmake/gen/amd64_microkernels.cmake) INCLUDE(cmake/gen/armsimd32_microkernels.cmake) INCLUDE(cmake/gen/avx_microkernels.cmake) INCLUDE(cmake/gen/avx2_microkernels.cmake) diff --git a/gemm_compiler/BUILD b/gemm_compiler/BUILD new file mode 100644 index 000000000000..1809c92142ea --- /dev/null +++ b/gemm_compiler/BUILD @@ -0,0 +1,80 @@ +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +load("@rules_python//python:py_binary.bzl", "py_binary") +load("@rules_python//python:py_library.bzl", "py_library") + +py_binary( + name = "generate_gemm_microkernels_main", + srcs = [ + "generate_gemm_microkernels_main.py", + ], + main = "generate_gemm_microkernels_main.py", + tags = [ + "notap", + ], + deps = [ + ":generate_gemm_microkernels", + ], +) + +py_library( + name = "generate_gemm_microkernels", + srcs = [ + "generate.py", + "generate_f32_gemm_microkernels.py", + "generate_qd8_f32_qc8w_gemm_microkernels.py", + ], + deps = [ + ":aarch64_arch_template", + ":aarch64_isa_templates", + ":x64_arch_template", + ":x64_isa_templates", + ], +) + +py_library( + name = "x64_isa_templates", + srcs = [ + "avx512f_template.py", + "avx512vnni_template.py", + "fma3_template.py", + ], +) + +py_library( + name = "x64_arch_template", + srcs = [ + "x64_template.py", + ], + deps = [ + ":base_architecture", + ], +) + +py_library( + name = "aarch64_isa_templates", + srcs = [ + "neondot_template.py", + "neonfma_template.py", + ], +) + +py_library( + name = "aarch64_arch_template", + srcs = [ + "aarch64_template.py", + ], + deps = [ + ":base_architecture", + ], +) + +py_library( + name = "base_architecture", + srcs = [ + "base_architecture.py", + ], +) diff --git a/gemm_compiler/aarch64_template.py b/gemm_compiler/aarch64_template.py new file mode 100644 index 000000000000..9364a5dc3387 --- /dev/null +++ b/gemm_compiler/aarch64_template.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from gemm_compiler import base_architecture as base_architecture + +"""All non SIMD features for aarch64.""" + + +class Aarch64(base_architecture.BaseArchitecture): + + def astride_register(self): + return 'x4' + + def kc_register(self): + return 'x2' + + def k_register(self): + return 'x20' + + def cm_stride_register(self): + return 'x7' + + def am_registers(self): + return [self.a_ptr_register()] + ['x9', 'x10', 'x11', 'x12', 'x21', 'x22'] + + def a_ptr_register(self): + return 'x3' + + def c_ptr_register(self): + return 'x6' + + def cm_registers(self): + return [self.c_ptr_register()] + ['x13', 'x14', 'x15', 'x19', 'x23', 'x24'] + + def w_ptr_register(self): + return 'x5' + + def min_register(self): + return 'v0' + + def max_register(self): + return 'v1' + + def nc_register(self): + return 'x1' + + def mr_register(self): + return 'x0' + + def tmp_gp_registers(self): + return ['x22', 'x23'] + + def dequantize(self, M, N, W): + return '' + + def adjust_kc(self): + return '' + + def register_map_byte(self, reg): + map = { + 'x0': 'x0', + 'x1': 'x1', + 'x2': 'x2', + 'x3': 'x3', + 'x4': 'x4', + 'x5': 'x5', + 'x6': 'x6', + 'x7': 'x7', + 'x8': 'x8', + 'x9': 'x9', + 'x10': 'x10', + 'x11': 'x11', + 'x12': 'x12', + 'x13': 'x13', + 'x14': 'x10', + 'x15': 'x15', + } + return map[reg] + + def register_map_dword(self, reg): + map = { + 'x0': 'q0', + 'x1': 'q1', + 'x2': 'q2', + 'x3': 'q3', + 'x4': 'q4', + 'x5': 'q5', + 'x6': 'q6', + 'x7': 'q7', + 'x8': 'q8', + 'x9': 'q9', + 'x10': 'q10', + 'x11': 'q11', + 'x12': 'q12', + 'x13': 'q13', + 'x14': 'q10', + 'x15': 'q15', + } + return map[reg] + + def function_name(self, M, N, isa): + return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_aarch64_{isa}_lane\n' + + def quantization_params(self): + return '' + + def header(self, M, N, prefix, isa): + HEADER = '#include "xnnpack/assembly.h"\n\n' + + HEADER += 'BEGIN_FUNCTION ' + self.function_name(M, N, isa) + HEADER += """ + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. 
+ ld2r {v0.4s, v1.4s}, [x13]\n""" + HEADER += self.quantization_params() + return HEADER + + def jump_to_label(self, label): + return f'b {label}\n' + + def read_a_registers(self, M): + return '' + + def inner_loop(self, M, N): + N_COUNT = N // self.n_step() + asm_string = '\ninner_loop:\n' + if 'before' in self.input_asm(): + asm_string += self.input_asm()['before'] + for mr in range(0, M): + for l in self.input_asm()['loop']: + asm_string += l.format( + AM_ptr=self.am_registers()[mr], + AM=self.a_registers(mr), + a_offset=self.k_register(), + ) + if 'after' in self.input_asm(): + asm_string += self.input_asm()['after'] + + # weights + if 'before' in self.weights_asm(): + asm_string += self.weights_asm()['before'] + for l in self.weights_asm()['loop_2']: + for nr in range(0, N_COUNT, 2): + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + W_1=self.w_registers()[nr + 1], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + for l in self.weights_asm()['loop']: + if N_COUNT % 2 != 0: + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + if 'after' in self.weights_asm(): + asm_string += self.weights_asm()['after'].format( + W=self.w_ptr_register(), w_step=self.register_bytes() * N_COUNT + ) + + for l in self.compute_asm()['loop']: + for nr in range(0, N_COUNT): + for mr in range(0, M): + asm_string += l.format( + W=self.w_registers()[nr], + A=self.a_registers(mr), + ACC=self.acc_registers()[M * nr + mr], + ) + return asm_string + + def outer_loop_prepare(self, M, N): + return '' + + def input_output_register_setup(self, M): + registers = self.am_registers() + a_stride = self.astride_register() + c_stride = self.cm_stride_register() + a_base_ptr = self.a_ptr_register() + c_base_ptr = self.c_ptr_register() + # setup a{0}->a{M-1} registers + if M == 1: + return '' + asm_string = '# Setup and alias a & c pointers.\n' + asm_string += self.input_output_strides( + M=M, registers=self.am_registers(), stride=self.astride_register() + ) + + # setup c{0}->c{M-1} registers + asm_string += self.input_output_strides( + M=M, registers=self.cm_registers(), stride=self.cm_stride_register() + ) + + # Pre outer loop preparation + # asm_string += isa.outer_loop_prepare(M=M, N=N_COUNT, W=w_ptr_reg, accumulators=acc_registers) + + # if mr < MR + clamp_string, outer = self.clamp_inputs_and_outputs( + M, self.labels(), self.am_registers(), self.cm_registers() + ) + asm_string += clamp_string + return asm_string + + def input_output_strides(self, M, registers, stride): + INPUT_OUTPUT_REGISTER_SETUP = """add {aM}, {aM_1}, {STRIDE}\n""" + ret = '' + for mr in range(1, M): + ret += INPUT_OUTPUT_REGISTER_SETUP.format( + M=mr, + M_1=mr - 1, + aM=registers[mr], + aM_1=registers[mr - 1], + STRIDE=stride, + ) + return ret + + def clamp_inputs_and_outputs( + self, M, labels, input_registers, output_registers + ): + clamping = { + 'clamp': """ + cmp {mr_reg}, {M} + csel {AM_1}, {AM_0}, {AM_1}, LO + csel {CM_1}, {CM_0}, {CM_1}, LO + csel {AM_2}, {AM_1}, {AM_2}, LS + csel {CM_2}, {CM_1}, {CM_2}, LS\n""", + } + ret = '' + outer = M + # clamp a & c + end_index = M if (M % 2 == 1) else M - 1 + for mr in range(2, end_index, 2): + ret += clamping['clamp'].format( + mr_reg=self.mr_register(), + AM_0=input_registers[mr - 2], + AM_1=input_registers[mr - 1], + AM_2=input_registers[mr], + CM_0=output_registers[mr - 2], + CM_1=output_registers[mr - 1], + 
CM_2=output_registers[mr], + M=mr, + ) + if end_index != M: + ret += """ + cmp {mr_reg}, {M} + csel {AM_1}, {AM_0}, {AM_1}, LO + csel {CM_1}, {CM_0}, {CM_1}, LO\n""".format( + mr_reg=self.mr_register(), + AM_0=input_registers[end_index - 1], + AM_1=input_registers[end_index], + CM_0=output_registers[end_index - 1], + CM_1=output_registers[end_index], + M=end_index + 1, + ) + + return ret, outer + + def increment_ptr(self, ptr, step): + return f'add {ptr}, {ptr}, {step}\n' + + def initialize_k_register(self, reg): + return 'mov {reg}, {kc_register}\n'.format(reg=reg, kc_register=self.kc_register()) + + def cmp_k_and_jump_if_less(self, label): + kc_register = self.kc_register() + k_register = self.k_register() + return """subs {k_register}, {k_register}, 4 + bne {label}\n""".format( + label=label, k_register=k_register, kc_register=kc_register + ) + + def epilogue(self, M, N, isa): + restore_stack = """ +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION {function_name}""".format( + M=M, N=N, function_name=isa.function_name(M, N, isa.isa()) + ) + return restore_stack diff --git a/gemm_compiler/avx512f_template.py b/gemm_compiler/avx512f_template.py new file mode 100644 index 000000000000..f8f941b32ec5 --- /dev/null +++ b/gemm_compiler/avx512f_template.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from gemm_compiler import fma3_template as isa +from gemm_compiler import x64_template as arch + +"""All SIMD features for avx512f.""" + + +class Avx512F(isa.Fma3): + + def __init__(self): + pass # Empty constructor + + def isa(self): + return 'avx512f' + + def register_bytes(self): + return 64 + + def prefix(self): + return 'z' + + def a_registers(self, idx): + registers = ['zmm2', 'zmm3', 'zmm4', 'zmm5', 'zmm6'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['zmm10', 'zmm11', 'zmm12', 'zmm13'] + + def n_step(self): + return 16 + + def input_asm(self): + res = super().input_asm() + res['compute'] = self.compute_asm()['loop'] + return res + + def dequantize(self, M, N, W): + return '' + + def adjust_kc(self): + return '' + + def compute_asm(self): + c_asm = { + 'loop': ['vfmadd231ps z{ACC}, {A}, {W}\n'], + } + return c_asm + + def outer_loop_prepare(self, M, N): + return '' + + def inner_loop_spill_gp(self, M, N): + asm_string = '\ninner_loop:\n' + N_COUNT = N // self.n_step() + # weights + if 'before' in self.weights_asm(): + asm_string += self.weights_asm()['before'] + for l in self.weights_asm()['loop']: + for nr in range(0, N_COUNT): + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + + # input + if 'before' in self.input_asm(): + asm_string += self.input_asm()['before'] + if 'after' in self.input_asm(): + asm_string += self.input_asm()['after'] + if 'after' in self.weights_asm(): + asm_string += self.weights_asm()['after'].format( + W=self.w_ptr_register(), w_step=self.register_bytes() * N_COUNT + ) + + for mr in range(0, M): + for l in self.input_asm()['loop']: + asm_string += l.format( + AM_ptr=self.am_registers()[mr], + 
AM=self.a_registers(0), + a_offset=self.k_register(), + W=self.w_registers()[nr], + A=self.a_registers(0), + ACC=self.acc_registers()[M * nr + mr], + ) + for m in self.input_asm()['compute']: + for nr in range(0, N_COUNT): + asm_string += m.format( + W=self.w_registers()[nr], + A=self.a_registers(0), + ACC=self.acc_registers()[M * nr + mr], + ) + return asm_string + + def inner_loop_small_M_N(self, M, N): + N_COUNT = N // self.n_step() + asm_string = '\ninner_loop:\n' + # input + if 'before' in self.input_asm(): + asm_string += self.input_asm()['before'] + if 'after' in self.input_asm(): + asm_string += self.input_asm()['after'] + + # weights + if 'before' in self.weights_asm(): + asm_string += self.weights_asm()['before'] + for l in self.weights_asm()['loop']: + for nr in range(0, N_COUNT): + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + if 'after' in self.weights_asm(): + asm_string += self.weights_asm()['after'].format( + W=self.w_ptr_register(), w_step=self.register_bytes() * N_COUNT + ) + + for mr in range(0, M): + for l in self.input_asm()['loop']: + asm_string += l.format( + AM_ptr=self.am_registers()[mr], + AM=self.a_registers(mr), + a_offset=self.k_register(), + W=self.w_registers()[nr], + A=self.a_registers(mr), + ACC=self.acc_registers()[M * nr + mr], + ) + for m in self.input_asm()['compute']: + for nr in range(0, N_COUNT): + asm_string += m.format( + W=self.w_registers()[nr], + A=self.a_registers(mr), + ACC=self.acc_registers()[M * nr + mr], + ) + return asm_string + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with the biases.\n' + W = self.w_ptr_register() + accumulators = self.acc_registers() + bias = 'vmovaps z{ACC}, [{W} + {offset}]\n' + for nr in range(0, N): + ret += bias.format( + W=W, ACC=accumulators[nr * M], offset=self.register_bytes() * nr + ) + for nr in range(0, N): + for mr in range(1, M): + ret += self.copy_simd_register( + prefix=self.prefix(), + src=accumulators[M * nr], + dst=accumulators[M * nr + mr], + ) + return ret + + def copy_simd_register(self, prefix, src, dst): + return f'vmovaps {prefix}{dst}, {prefix}{src}\n' + + def store( + self, + M, + N, + ): + tmp_gp_regs = self.tmp_gp_registers() + accumulators = self.acc_registers() + cm_registers = self.cm_registers() + nc_reg = self.nc_register() + nc_lo = self.register_map_byte(nc_reg) + pop_c = M > self.max_M_before_spilling() + N_COUNT = N // self.n_step() + asm_string = '' + c_reg_offset = self.max_M_before_spilling() + if pop_c: + asm_string += '\n' + '# Pop output pointers from the stack.\n' + c_reg_offset = 0 + POP_C = 'mov {C_REG}, [rsp - {offset}]\n' + for mr in range(0, M): + sp_offset = 128 + (mr) * 16 + 8 + asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset) + asm_string += """ + # Check whether full or partial store. 
+ cmp {nc}, {n_step} + jl tail\n""".format(n_step=N, N_2=N // 2, nc=nc_reg) + for mr in range(0, M): + asm_string += """ + vmovups [{c_reg}], z{ACC}""".format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + for nr in range(1, N_COUNT): + asm_string += """ + vmovups [{c_reg} + {offset}], z{ACC}""".format( + ACC=accumulators[M * nr + mr], + c_reg=cm_registers[mr + c_reg_offset], + offset=self.register_bytes() * nr, + ) + asm_string += '\n' + for mr in range(0, M): + asm_string += 'add {cm}, {cn_stride}\n'.format( + cn_stride=N_COUNT * 64, cm=cm_registers[mr + c_reg_offset] + ) + if pop_c: + asm_string += '\n' + '# Write output pointers to the stack.\n' + POP_C = 'mov [rsp - {offset}], {C_REG}\n' + for mr in range(0, M): + sp_offset = 128 + (mr) * 16 + 8 + asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset) + CHECK = """ + sub {nc}, {n_step} + jne outer_loop + jmp return\n""".format(n_step=N, nc=nc_reg) + asm_string += CHECK + + asm_string += '\ntail:' + if N == 64: + asm_string += """ + mov {tmp1}, -1 + sal {tmp1}, {nc_lo} + not {tmp1} + kmovw k1, {tmp1_lo} + shr {tmp1}, 16 + kmovw k2, {tmp1_lo} + shr {tmp1}, 16 + kmovw k3, {tmp1_lo} + shr {tmp1}, 16 + kmovw k4, {tmp1_lo}\n + """.format( + nc_reg=nc_reg, + tmp1=tmp_gp_regs[1], + tmp1_lo=self.register_map_dword(tmp_gp_regs[1]), + nc_lo='cl', + ACC=accumulators[0], + c_reg=cm_registers[0], + ) + for mr in range(0, M): + asm_string += 'vmovups ZMMWORD PTR [{c_reg}]{{k1}}, z{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 64]{{k2}}, z{ACC}\n'.format( + ACC=accumulators[mr + M], c_reg=cm_registers[mr + c_reg_offset] + ) + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 128]{{k3}}, z{ACC}\n'.format( + ACC=accumulators[mr + 2 * M], + c_reg=cm_registers[mr + c_reg_offset], + ) + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 192]{{k4}}, z{ACC}\n'.format( + ACC=accumulators[mr + 3 * M], + c_reg=cm_registers[mr + c_reg_offset], + ) + ) + elif N == 32: + asm_string += """ + mov {tmp1_lo}, -1 + sal {tmp1_lo}, {nc_lo} + not {tmp1_lo} + kmovw k1, {tmp1_lo} + shr {tmp1_lo}, 16 + kmovw k2, {tmp1_lo}\n""".format( + nc_reg=nc_reg, + tmp1_lo=self.register_map_dword(tmp_gp_regs[1]), + nc_lo='cl', + ACC=accumulators[0], + c_reg=cm_registers[0], + ) + for mr in range(0, M): + asm_string += 'vmovups ZMMWORD PTR [{c_reg}]{{k1}}, z{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 64]{{k2}}, z{ACC}\n'.format( + ACC=accumulators[mr + M], c_reg=cm_registers[mr + c_reg_offset] + ) + ) + else: + asm_string += """ + mov {tmp1_lo}, -1 + sal {tmp1_lo}, {nc_lo} + not {tmp1_lo} + kmovw k1, {tmp1_lo}\n""".format( + nc_reg=nc_reg, + tmp1_lo=self.register_map_dword(tmp_gp_regs[1]), + nc_lo='cl', + ACC=accumulators[0], + c_reg=cm_registers[0 + c_reg_offset], + ) + for mr in range(0, M): + asm_string += 'vmovups ZMMWORD PTR [{c_reg}]{{k1}}, z{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + + return asm_string diff --git a/gemm_compiler/avx512vnni_template.py b/gemm_compiler/avx512vnni_template.py new file mode 100644 index 000000000000..a7da45726b4f --- /dev/null +++ b/gemm_compiler/avx512vnni_template.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
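# Illustrative sketch (not part of the generator): the tail-store path in the
# Avx512F template above builds its AVX-512 write masks from the number of
# remaining output columns with `mov -1; sal ..., cl; not; kmovw`, peeling one
# 16-bit mask per ZMM register.  In Python terms:
def avx512_tail_masks(nc, n_step=16, n_registers=4):
  """Mask values loaded into k1..k4 for a partial store of `nc` columns."""
  bits = ~(-1 << nc) & ((1 << (n_step * n_registers)) - 1)
  return [(bits >> (n_step * i)) & 0xFFFF for i in range(n_registers)]

# For example, avx512_tail_masks(37) == [0xFFFF, 0xFFFF, 0x001F, 0x0000].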
+ +from gemm_compiler import avx512f_template as isa + +"""All SIMD features for avx512vnni.""" + + +class Avx512Vnni(isa.Avx512F): + + def isa(self): + return 'avx512vnni' + + def a_registers(self, idx): + return 'zmm2' + + def scale_registers(self): + return ['zmm10', 'zmm11', 'zmm2', 'zmm3'] + + def w_registers(self): + return ['zmm6', 'zmm7', 'zmm8', 'zmm9'] + + def acc_registers(self): + return [ + 'mm5', + 'mm12', + 'mm14', + 'mm15', + 'mm16', + 'mm17', + 'mm18', + 'mm19', + 'mm20', + 'mm21', + 'mm22', + 'mm23', + 'mm24', + 'mm25', + 'mm26', + 'mm27', + 'mm28', + 'mm29', + 'mm30', + 'mm4', + 'mm8', + 'mm9', + ] + + def function_name(self, M, N, isa): + return f'xnn_qd8_f32_qc8w_gemm_minmax_ukernel_{M}x{N}c4__asm_amd64_{isa}\n' + + def zp_scale(self, pos): + regs = ['10', '11'] + return regs[pos] + + # kc = round_up_po2(kc, channels) + def adjust_kc(self): + channels = 4 + ret = """ + add {kc_reg}, {channels} + and {kc_reg}, {neg_channels}\n""".format( + kc_reg=self.kc_register(), channels=channels - 1, neg_channels=-channels + ) + return ret + + def quantization_params(self): + return """ + mov {quantization_params_reg}, [rsp + 88] + """.format(quantization_params_reg=self.quantization_params_register()) + + def quantization_params_register(self): + return self.k_register() + + def input_asm(self): + in_asm = { + 'loop': [ + 'vpbroadcastd {AM}, [{AM_ptr} + {a_offset}]\n', + ], + 'compute': ['vpdpbusd z{ACC}, {A}, {W}\n'], + } + return in_asm + + def weights_asm(self): + w_asm = { + 'loop': [ + 'vmovaps {W}, [{W_ptr} + {offset}]\n', + ], + 'after': 'add {W}, {w_step}\n', + } + return w_asm + + def compute_asm(self): + c_asm = { + 'loop': ['vpdpbusd z{ACC}, {A}, {W}\n'], + } + return c_asm + + def dequantize(self, M, N, W): + accumulators = self.acc_registers() + ret = '' + ret += '\n# Convert from int32 to float.\n' + for nr in range(0, N * M): + ret += 'vcvtdq2ps z{ACC}, z{ACC}\n'.format(ACC=accumulators[nr]) + ret += '# Load quantization_params pointer from stack\n' + ret += 'mov {quantization_params_reg}, [rsp + {offset}]\n'.format( + quantization_params_reg=self.quantization_params_register(), + offset=self.stack_size(M) + 88, + ) + for nr in range(0, N): + for mr in range(0, M): + ret += ( + 'vmulps z{ACC}, z{ACC}, DWORD PTR [{quantization_params_reg} +' + ' {offset}]{{1to16}}\n'.format( + ACC=accumulators[nr * M + mr], + offset=4 + mr * 8, + quantization_params_reg=self.quantization_params_register(), + ) + ) + output_scale = 'vmovaps {W_SCALE}, [{W} + {offset}]\n' + # output scales + for nr in range(0, N): + ret += output_scale.format( + W=W, + offset=self.register_bytes() * nr, + W_SCALE=self.scale_registers()[nr], + ) + ret += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + # biases + for nr in range(0, N): + ret += output_scale.format( + W=W, offset=self.register_bytes() * nr, W_SCALE=self.w_registers()[nr] + ) + ret += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + # Intel gets points here for its fma instructions which can accumulate into + # any of the registers. For once, Intel has saner instructions than Arm. 
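    # vfmadd213ps dst, src1, src2 computes dst = src1 * dst + src2, so each
    # accumulator is multiplied by its per-channel weight scale and the bias is
    # added in a single instruction, with no extra register copy.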
+ for nr in range(0, N): + for mr in range(0, M): + ret += 'vfmadd213ps z{ACC}, {SCALE}, {BIAS}\n'.format( + ACC=accumulators[nr * M + mr], + SCALE=self.scale_registers()[nr], + BIAS=self.w_registers()[nr], + ) + + return ret + + def outer_loop_prepare(self, M, N): + W = self.w_ptr_register() + accumulators = self.acc_registers() + # outside the outer loop + zp_scale_load_push = ( + """mov {tmp_reg}, [{quantization_params_reg} + {zp_offset}] + vpbroadcastd {tmp_s_reg}, {tmp_reg} + vmovups zmmword ptr [rsp + {offset}], {tmp_s_reg}\n""" + ) + ret = '\n# Load quantization params pointer from stack\n' + ret += 'mov {quantization_params_reg}, [rsp + {offset}]\n'.format( + quantization_params_reg=self.quantization_params_register(), + offset=self.stack_size(M) + 88, + ) + for mr in range(0, M, 1): + ret += zp_scale_load_push.format( + tmp_reg=self.register_map_dword(self.tmp_gp_registers()[0]), + quantization_params_reg=self.quantization_params_register(), + tmp_s_reg=self.w_registers()[0], + offset=464 + mr * 64, + zp_offset=mr * 8, + ) + return ret + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with k_sum * input zero point.\n' + accumulators = self.acc_registers() + W = self.w_ptr_register() + + ksum_x16 = 'vmovaps {KSUM}, [{W} + {offset}]\n' + vksum = 'vpmulld z{ACC}, {KSUM}, ZMMWORD PTR [rsp + {offset}]\n' + + for nr in range(0, N): + ret += ksum_x16.format( + W=W, KSUM=self.w_registers()[nr], offset=self.register_bytes() * nr + ) + for nr in range(0, N): + for mr in range(0, M): + ret += vksum.format( + ACC=accumulators[nr * M + mr], + KSUM=self.w_registers()[nr], + pos=int((mr % 2) * 2), + offset=464 + mr * 64, + ) + + return ret + + def stack_size(self, M): + return 464 + M * 64 diff --git a/gemm_compiler/base_architecture.py b/gemm_compiler/base_architecture.py new file mode 100644 index 000000000000..9133ef58eb54 --- /dev/null +++ b/gemm_compiler/base_architecture.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
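# Illustrative scalar model (not part of the generator) of the arithmetic the
# qd8-f32-qc8w templates above emit: the int32 accumulator is seeded with
# k_sum * input_zero_point, vpdpbusd adds the int8 dot products, and the
# dequantize step converts to float and applies the input scale, the
# per-channel weight scale and the bias.  Names below are illustrative only.
def dequantize_reference(acc_int32, input_scale, weight_scale, bias):
  return float(acc_int32) * input_scale * weight_scale + bias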
+ +from abc import abstractmethod + +from gemm_compiler import base_architecture as base_architecture + +"""Base architecture for GEMM microkernel generation""" + + +class BaseArchitecture: + + def __init__(self): + pass # Empty constructor + + def labels(self): + return [ + 'zero', + 'one', + 'two', + 'three', + 'four', + 'five', + 'six', + 'seven', + 'eight', + 'nine', + 'ten', + 'eleven', + ] + + @abstractmethod + def astride_register(self): + """Returns the register containing a_stride.""" + raise NotImplementedError + + @abstractmethod + def kc_register(self): + """Returns the register containing kc, the number of channels (the reduction dimensions).""" + raise NotImplementedError + + @abstractmethod + def k_register(self): + """Returns the register containing k, the current channel being processed.""" + raise NotImplementedError + + @abstractmethod + def cm_stride_register(self): + """Returns the register containing cm_stride.""" + raise NotImplementedError + + @abstractmethod + def am_registers(self): + """Returns the registers containing the pointers to each row of A (LHS).""" + raise NotImplementedError + + @abstractmethod + def a_ptr_register(self): + """Returns the register containing the A pointer.""" + raise NotImplementedError + + @abstractmethod + def c_ptr_register(self): + """Returns the register containing the C pointer.""" + raise NotImplementedError + + @abstractmethod + def cm_registers(self): + """Returns the registers containing the pointers to each row of C (Output).""" + raise NotImplementedError + + @abstractmethod + def acc_registers(self): + """Returns the accumulator registers.""" + raise NotImplementedError + + @abstractmethod + def w_ptr_register(self): + """Returns the register containing the weight's pointer.""" + raise NotImplementedError + + @abstractmethod + def min_register(self): + """Returns the register containing the min value for clamping.""" + raise NotImplementedError + + @abstractmethod + def max_register(self): + """Returns the register containing the max value for clamping.""" + raise NotImplementedError + + @abstractmethod + def nc_register(self): + """Returns the register containing nc, the number of output rows processed per iteration.""" + raise NotImplementedError + + @abstractmethod + def mr_register(self): + """Returns the register containing mr, the number of input rows processed per kernel call.""" + raise NotImplementedError + + @abstractmethod + def tmp_gp_registers(self): + """Returns some general purpose registers which may be used for storing temporary data.""" + raise NotImplementedError + + @abstractmethod + def jump_to_label(self, label): + """Jump to the given label.""" + raise NotImplementedError + + @abstractmethod + def function_name(self, M, N, isa): + """Returns the microkernel name.""" + raise NotImplementedError + + @abstractmethod + def header(self, M, N, prefix, isa): + """Returns the assembly header.""" + raise NotImplementedError + + @abstractmethod + def input_output_register_setup(self, M): + """Setup the input (A) and output (C) registers.""" + raise NotImplementedError + + @abstractmethod + def max_M_before_spilling(self): + """How large can M be before spilling A and C registers to the stack.""" + raise NotImplementedError + + @abstractmethod + def read_a_registers(self, M): + """Read the A registers from the stack.""" + raise NotImplementedError + + @abstractmethod + def increment_ptr(self, ptr, step): + """Increment the given pointer by step bytes.""" + raise NotImplementedError + + @abstractmethod + def 
initialize_k_register(self, reg): + """Initialized the given general purpose register for inner loop control.""" + raise NotImplementedError + + @abstractmethod + def cmp_k_and_jump_if_less(self, label): + """If k is less than kc, then do another iteration of the inner loop.""" + raise NotImplementedError + + @abstractmethod + def epilogue(self, M, N, isa): + """Returns the function epilogue.""" + raise NotImplementedError + + @abstractmethod + def inner_loop(self, M, N): + """Returns the assemebly for the microkernel's inner loop.""" + raise NotImplementedError diff --git a/gemm_compiler/fma3_template.py b/gemm_compiler/fma3_template.py new file mode 100644 index 000000000000..2d59b4249c2e --- /dev/null +++ b/gemm_compiler/fma3_template.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from gemm_compiler import x64_template as arch + +"""All SIMD features for fma3.""" + + +class Fma3(arch.X64): + + def __init__(self): + pass # Empty constructor + + def isa(self): + return 'fma3' + + def register_bytes(self): + return 32 + + def prefix(self): + return 'y' + + def a_registers(self, idx): + registers = ['ymm2', 'ymm3', 'ymm4', 'ymm5'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['ymm14', 'ymm15'] + + def input_asm(self): + in_asm = { + 'loop': [ + 'vbroadcastss {AM}, DWORD PTR [{AM_ptr} + {a_offset}]\n', + ] + } + return in_asm + + def weights_asm(self): + w_asm = { + 'loop': [ + 'vmovaps {W}, [{W_ptr} + {offset}]\n', + ], + 'after': 'add {W}, {w_step}\n', + } + return w_asm + + def compute_asm(self): + c_asm = { + 'loop': ['vfmadd231ps y{ACC}, {A}, {W}\n'], + } + return c_asm + + def load_bias(self): + return 'vmovaps y{ACC}, [{W} + {offset}]\n' + + def copy_simd_register(self, prefix, src, dst): + return f'vmovaps {prefix}{dst}, {prefix}{src}\n' + + def clamp_min(self, reg, prefix): + max_reg = self.max_register() + return f'vminps {prefix}{reg}, {prefix}{max_reg}, {prefix}{reg}\n' + + def clamp_max(self, reg, prefix): + min_reg = self.min_register() + return f'vmaxps {prefix}{reg}, {prefix}{min_reg}, {prefix}{reg}\n' + + def store( + self, + M, + N, + ): + accumulators = self.acc_registers() + cm_registers = self.cm_registers() + nc_reg = self.nc_register() + nc_lo = self.register_map_byte(nc_reg) + N_STEP = 8 + N_COUNT = N // N_STEP + asm_string = """ + cmp {nc}, {n_step} + jl tail_{N_2} + """.format(n_step=N, N_2=N // 2, nc=nc_reg) + for mr in range(0, M): + asm_string += 'vmovups [{c_reg}], y{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for nr in range(1, N_COUNT): + asm_string += 'vmovups [{c_reg} + {offset}], y{ACC}\n'.format( + ACC=accumulators[M * nr + mr], + c_reg=cm_registers[mr], + offset=isa.register_bytes() * nr, + ) + for mr in range(0, M): + asm_string += 'add {cm}, {cn_stride}\n'.format( + cn_stride=cn_stride_reg, cm=cm_registers[mr] + ) + CHECK = """ + sub {nc}, {n_step} + jne {OUTER} + jmp return""".format(n_step=N, nc=nc_reg, OUTER=labels[M]) + asm_string += CHECK + N = N // 2 + if N * 2 > N_STEP: + asm_string += """ + tail_{N}: + test {nc_lo}, {N} + jz tail_{N_2}\n""".format(N=N, N_2=N // 2, nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovups [{c_reg}], y{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + N_COUNT = N // N_STEP + for nr in range(1, N_COUNT): + for mr in range(0, M): + asm_string += 'vmovups 
[{c_reg} + {offset}], y{ACC}\n'.format( + ACC=accumulators[M * nr + mr], + c_reg=cm_registers[mr], + offset=isa.register_bytes() * nr, + ) + for mr in range(0, M): + asm_string += 'vmovaps y{ACC0}, y{ACC1}\n'.format( + ACC0=accumulators[mr], ACC1=accumulators[mr + M * nr] + ) + for mr in range(0, M): + asm_string += 'add {cm}, 32\n'.format( + cn_stride=cn_stride_reg, cm=cm_registers[mr] + ) + asm_string += """ +tail_4: + test {nc_lo}, 4 + jz tail_2\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovups [{c_reg}], x{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'add {c_reg}, 16\n'.format(c_reg=cm_registers[mr]) + for mr in range(0, M): + asm_string += 'vextractf128 x{ACC}, y{ACC}, 1\n'.format( + ACC=accumulators[mr] + ) + asm_string += """ +tail_2: + test {nc_lo}, 2 + jz tail_1\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovlps QWORD PTR [{c_reg}], x{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'add {c_reg}, 8\n'.format(c_reg=cm_registers[mr]) + for mr in range(0, M): + asm_string += 'vmovhlps x{ACC}, x{ACC}, x{ACC}\n'.format( + ACC=accumulators[mr] + ) + asm_string += """ +tail_1: + test {nc_lo}, 1 + jz return\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovss DWORD PTR [{c_reg}], x{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + + return asm_string diff --git a/gemm_compiler/generate.py b/gemm_compiler/generate.py new file mode 100644 index 000000000000..d5a99898f994 --- /dev/null +++ b/gemm_compiler/generate.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import sys + +from gemm_compiler import base_architecture + +"""Shared logic for assembly gemm microkernel generation.""" + + +def generate_gemm_microkernel( + M: int, N: int, isa: base_architecture.BaseArchitecture, output_file: str +): + elements_per_register = isa.n_step() + num_horizontal_registers = int(N / elements_per_register) + asm_string = isa.header(M, N, isa.prefix(), isa.isa()) + + k_register = isa.k_register() + acc_registers = isa.acc_registers() + w_ptr_reg = isa.w_ptr_register() + + # adjust inner loop + asm_string += isa.adjust_kc() + + # setup a{1}->a{M-1} & c{1]->c{M-1}registers + asm_string += isa.input_output_register_setup( + M=M, + ) + + # Pre outer loop preparation + asm_string += isa.outer_loop_prepare(M=M, N=num_horizontal_registers) + + # the outer loop label + asm_string += '\nouter_loop:\n' + asm_string += '# Initialize k counter.\n' + asm_string += isa.initialize_k_register(k_register) + + # Read a registers from the stack if required + asm_string += isa.read_a_registers(M=M) + + # Initialize accumulators + asm_string += isa.init_accumulators( + M=M, + N=num_horizontal_registers, + ) + asm_string += isa.increment_ptr( + ptr=w_ptr_reg, step=isa.register_bytes() * num_horizontal_registers + ) + + # inner loop + asm_string += isa.inner_loop(M, N) + + # loop counter + asm_string += isa.cmp_k_and_jump_if_less(label='inner_loop') + + asm_string += isa.dequantize(M=M, N=num_horizontal_registers, W=w_ptr_reg) + + # min/max clamping + asm_string += '# Min/max clamping..\n' + for nr in range(0, num_horizontal_registers): + for mr in range(0, M): + asm_string += isa.clamp_min( + reg=acc_registers[M * nr + mr], prefix=isa.prefix() + ) + for nr in range(0, num_horizontal_registers): + for mr in range(0, M): + asm_string += isa.clamp_max( + reg=acc_registers[M * nr + mr], prefix=isa.prefix() + ) + + # store + asm_string += isa.store( + M=M, + N=N, + ) + + asm_string += isa.epilogue(M, N, isa) + + # Correctly indent the generated assembly. + lines = asm_string.splitlines() + stripped_lines = [line.lstrip() for line in lines] + # Indent all lines that are not labels. + stripped_lines = [ + ' ' + line + if not (line.endswith(':') or 'FUNCTION' in line or 'include' in line) + else line + for line in stripped_lines + ] + # Strip indentation from empty lines. + stripped_lines = ['' if line.isspace() else line for line in stripped_lines] + asm_string = '\n'.join(stripped_lines) + + with open(output_file, 'w') as f: + f.write(asm_string) diff --git a/gemm_compiler/generate_f32_gemm_microkernels.py b/gemm_compiler/generate_f32_gemm_microkernels.py new file mode 100644 index 000000000000..d984467ffb1b --- /dev/null +++ b/gemm_compiler/generate_f32_gemm_microkernels.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
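# For illustration only (not part of the generator): the indentation pass in
# generate.py above leaves labels, FUNCTION markers and includes flush-left and
# indents every other line by a fixed amount, so
#   outer_loop:
#   mov r11, 0
# becomes
#   outer_loop:
#       mov r11, 0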
+ +import os +import sys + +from gemm_compiler import avx512f_template +from gemm_compiler import fma3_template +from gemm_compiler import generate +from gemm_compiler import neonfma_template + +"""Generates f32 assembly gemm microkernels.""" + + +output_base = 'src/f32-gemm/gen/' + + +def generate_f32_gemm_microkernels(): + if '/bazel-out/' in os.getcwd(): + os.chdir(os.environ['BUILD_WORKING_DIRECTORY']) + + for nr in range(16, 33, 16): + for mr in range(1, 12): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=avx512f_template.Avx512F(), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x{nr}-minmax-asm-amd64-avx512f-broadcast.S', + ), + ) + + # not enough SIMD registers to go above 5x64 + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=64, + isa=avx512f_template.Avx512F(), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x64-minmax-asm-amd64-avx512f-broadcast.S', + ), + ) + + for nr in range(8, 17, 8): + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=neonfma_template.NeonFma(), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x{nr}-minmax-asm-aarch64-neonfma-ld32.S', + ), + ) diff --git a/gemm_compiler/generate_gemm_microkernels_main.py b/gemm_compiler/generate_gemm_microkernels_main.py new file mode 100644 index 000000000000..46dd710fcf39 --- /dev/null +++ b/gemm_compiler/generate_gemm_microkernels_main.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys + +from gemm_compiler import generate_f32_gemm_microkernels as f32 +from gemm_compiler import generate_qd8_f32_qc8w_gemm_microkernels as qd8_f32_qc8w + +"""Generates all assembly gemm microkernels.""" + + +def main(args): + + f32.generate_f32_gemm_microkernels() + qd8_f32_qc8w.generate_qd8_f32_qc8w_gemm_microkernels() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/gemm_compiler/generate_qd8_f32_qc8w_gemm_microkernels.py b/gemm_compiler/generate_qd8_f32_qc8w_gemm_microkernels.py new file mode 100644 index 000000000000..a2c5ab13a59f --- /dev/null +++ b/gemm_compiler/generate_qd8_f32_qc8w_gemm_microkernels.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
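# Rough register-budget check (illustrative, not part of the generator) behind
# the "not enough SIMD registers to go above 5x64" comments in these generator
# scripts: a kernel needs M * (N / 16) ZMM accumulators and the x64 templates
# only name 22 of them, so 5x64 (20) and 11x32 (22) fit while 6x64 (24) does
# not.
def fits_in_accumulators(M, N, n_step=16, num_accumulators=22):
  return M * (N // n_step) <= num_accumulators

assert fits_in_accumulators(5, 64) and fits_in_accumulators(11, 32)
assert not fits_in_accumulators(6, 64)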
+ +import os +import sys + +from gemm_compiler import avx512vnni_template +from gemm_compiler import generate +from gemm_compiler import neondot_template + +"""Generates qd8-f32-qc8w assembly gemm microkernels.""" + + +output_base = 'src/qd8-f32-qc8w-gemm/gen/' + + +def generate_qd8_f32_qc8w_gemm_microkernels(): + if '/bazel-out/' in os.getcwd(): + os.chdir(os.environ['BUILD_WORKING_DIRECTORY']) + + for nr in range(8, 17, 8): + for mr in range(1, 5): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=neondot_template.NeonDot(), + output_file=os.path.join( + output_base, + f'qd8-f32-qc8w-gemm-{mr}x{nr}-minmax-asm-aarch64-neondot-ld32.S', + ), + ) + + for nr in range(16, 33, 16): + for mr in range(1, 12): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=avx512vnni_template.Avx512Vnni(), + output_file=os.path.join( + output_base, + f'qd8-f32-qc8w-gemm-{mr}x{nr}-minmax-asm-amd64-avx512vnni.S', + ), + ) + + # not enough SIMD registers to go above 5x64 + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=64, + isa=avx512vnni_template.Avx512Vnni(), + output_file=os.path.join( + output_base, + f'qd8-f32-qc8w-gemm-{mr}x64-minmax-asm-amd64-avx512vnni.S', + ), + ) diff --git a/gemm_compiler/neondot_template.py b/gemm_compiler/neondot_template.py new file mode 100644 index 000000000000..70400d9740a7 --- /dev/null +++ b/gemm_compiler/neondot_template.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from gemm_compiler import neonfma_template as isa + +"""All SIMD features for Aarch64 neondot.""" + + +class NeonDot(isa.NeonFma): + + def isa(self): + return 'neondot' + + def a_registers(self, idx): + registers = ['2', '3', '4', '5'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['6', '7', '8', '9'] + + def acc_registers(self): + return [ + '12', + '13', + '14', + '15', + '16', + '17', + '18', + '19', + '20', + '21', + '22', + '23', + '24', + '25', + '26', + '27', + '28', + '29', + '30', + ] + + def function_name(self, M, N, isa): + return f'xnn_qd8_f32_qc8w_gemm_minmax_ukernel_{M}x{N}c4__asm_aarch64_{isa}_lane\n' + + def zp_scale(self, pos): + regs = ['10', '11'] + return regs[pos] + + # kc = round_up_po2(kc, channels) + def adjust_kc(self): + channels = 4 + m = pow(2, 64) - channels + not_channels = f'0x{m:016X}' + ret = '# Round kc up to channels.\n' + ret += """add {kc_reg}, {kc_reg}, #{channels} + and {kc_reg}, {kc_reg}, #{not_channels}\n\n""".format( + kc_reg=self.kc_register(), + channels=channels - 1, + not_channels=not_channels, + ) + return ret + + def quantization_params(self): + return """ldr {quantization_params_reg}, [sp, 16]\n""".format( + quantization_params_reg=self.quantization_params_register() + ) + + def quantization_params_register(self): + return 'x24' + + def compute_asm(self): + c_asm = { + 'loop': ['sdot v{ACC}.4s, v{W}.16b, v{A}.4b[0]\n'], + } + return c_asm + + def dequantize(self, M, N, W): + accumulators = self.acc_registers() + ret = '\n# Convert from int32 to float.\n' + for nr in range(0, N * M): + ret += 'scvtf v{ACC}.4s, v{ACC}.4s\n'.format(ACC=accumulators[nr]) + ret += '# Multiply by input scale.\n' + for nr in range(0, N): + for mr in range(0, M): + ret += 'fmul v{ACC}.4s, v{ACC}.4s, v{zp_scale}.s[{pos}]\n'.format( + ACC=accumulators[nr * M + mr], + zp_scale=self.zp_scale(mr // 2), + pos=int((mr % 2) * 2) + 1, + ) + ret += 
'# Load weights scale.\n' + output_scale_pair = 'ldp q{W_SCALE_0}, q{W_SCALE_1}, [{W}, {offset}]\n' + # output scales + for nr in range(0, N, 2): + ret += output_scale_pair.format( + W=W, + offset=self.register_bytes() * nr, + W_SCALE_0=self.a_registers(nr), + W_SCALE_1=self.a_registers(nr + 1), + ) + ret += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + # biases + ret += '# Load biases.\n' + for nr in range(0, N, 2): + ret += output_scale_pair.format( + W=W, + offset=self.register_bytes() * nr, + W_SCALE_0=self.w_registers()[nr], + W_SCALE_1=self.w_registers()[nr + 1], + ) + ret += 'add {W}, {W}, {increment}\n'.format( + W=W, increment=self.register_bytes() * N + ) + # do mul + add here instead of fmla. + # fmla accumulaltes into the additional term, in this case the bias. This + # means that the bias must be copied before the fmla. + # From the Cortex X1 optimization guide, fmov takes 1 cycle with a + # throughput of 4 and fmla takes 4 cycles with a throughput of 4. This means + # 5 cycles for four movs + fmla. fadd takes 2 cycles with a throughput of 4 + # and fmul takes 3 cycles with a throughput of 4, for a total of 5 cycles + # for 4 results. + ret += "# Multiply by weight's scale.\n" + for nr in range(0, N): + for mr in range(0, M): + ret += 'fmul v{ACC}.4s, v{ACC}.4s, v{SCALE}.4s\n'.format( + ACC=accumulators[nr * M + mr], SCALE=self.a_registers(nr) + ) + ret += '# Add bias.\n' + for nr in range(0, N): + for mr in range(0, M): + ret += 'fadd v{ACC}.4s, v{ACC}.4s, v{BIAS}.4s\n'.format( + ACC=accumulators[nr * M + mr], BIAS=self.w_registers()[nr] + ) + + return ret + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with k_sum * input zero point.\n' + accumulators = self.acc_registers() + W = self.w_ptr_register() + zp_scale_x2 = 'ldr q{zp_scale}, [{quantization_params_reg}]\n' + zp_scale_x4 = ( + 'ldp q{zp_scale_0}, q{zp_scale_1}, [{quantization_params_reg}]\n' + ) + ksum_x8 = 'ldp q{KSUM_0}, q{KSUM_1}, [{W}, {offset}]\n' + vksum = 'mul v{ACC}.4s, v{KSUM}.4s, v{zp_scale}.s[{pos}]\n' + + mr = 0 + for mr in range(0, M - 1, 4): + ret += zp_scale_x4.format( + quantization_params_reg=self.quantization_params_register(), + zp_scale_0=self.zp_scale(mr), + zp_scale_1=self.zp_scale(mr + 1), + ) + if M % 2 == 1: + ret += zp_scale_x2.format( + quantization_params_reg=self.quantization_params_register(), + zp_scale=self.zp_scale(mr), + ) + for nr in range(0, N - 1, 2): + ret += ksum_x8.format( + W=W, + KSUM_0=self.a_registers(nr), + KSUM_1=self.a_registers(nr + 1), + offset=self.register_bytes() * nr, + ) + for nr in range(0, N): + for mr in range(0, M): + ret += vksum.format( + ACC=accumulators[nr * M + mr], + KSUM=self.a_registers(nr), + zp_scale=self.zp_scale(mr // 2), + pos=int((mr % 2) * 2), + ) + + return ret diff --git a/gemm_compiler/neonfma_template.py b/gemm_compiler/neonfma_template.py new file mode 100644 index 000000000000..cca92ed2611a --- /dev/null +++ b/gemm_compiler/neonfma_template.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
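# Illustrative only (not part of the generator): the adjust_kc helpers in the
# avx512vnni and neondot templates above round kc up to the channel-group size
# (4 here) with an add/and pair, i.e. kc = (kc + 3) & ~3:
def round_up_po2(kc, channels=4):
  return (kc + channels - 1) & ~(channels - 1)

# round_up_po2(13) == 16 and round_up_po2(16) == 16.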
+ +import math + +from gemm_compiler import aarch64_template as arch + +"""All SIMD features for Aarch64 neondot.""" + + +class NeonFma(arch.Aarch64): + + def n_step(self): + return 4 + + def isa(self): + return 'neonfma' + + def register_bytes(self): + return 16 + + def prefix(self): + return 'v' + + def acc_registers(self): + return [ + '11', + '12', + '13', + '14', + '15', + '16', + '17', + '18', + '19', + '20', + '21', + '22', + '23', + '24', + '25', + '26', + '27', + '28', + '29', + '30', + ] + + def a_registers(self, idx): + registers = ['2', '3', '4', '5', '6'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['7', '8', '9', '10'] + + def input_asm(self): + in_asm = { + 'loop': [ + 'ldr s{AM}, [{AM_ptr}], 4\n', + ] + } + return in_asm + + def weights_asm(self): + w_asm = { + 'loop': [ + 'ldr q{W}, [{W_ptr}], 16\n', + ], + 'loop_2': [ + 'ldp q{W}, q{W_1}, [{W_ptr}], 32\n', + ], + } + return w_asm + + def compute_asm(self): + c_asm = { + 'loop': ['fmla v{ACC}.4s, v{W}.4s, v{A}.s[0]\n'], + } + return c_asm + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with the biases.\n' + accumulators = self.acc_registers() + W = self.w_ptr_register() + single_bias = 'ldr q{ACC}, [{W}, {offset}]\n' + pair_bias = 'ldp q{ACC}, q{ACC_1}, [{W}, {offset}]\n' + + for nr in range(0, N - 1, 2): + ret += pair_bias.format( + W=W, + ACC=accumulators[nr * M], + ACC_1=accumulators[nr * M + M], + offset=self.register_bytes() * nr, + ) + if N % 2 != 0: + ret += single_bias.format( + W=W, + ACC=accumulators[(N - 1) * M], + offset=self.register_bytes() * (N - 1), + ) + for nr in range(0, N): + for mr in range(1, M): + ret += self.copy_simd_register( + prefix=self.prefix(), + src=accumulators[M * nr], + dst=accumulators[M * nr + mr], + ) + + return ret + + def copy_simd_register(self, prefix, src, dst): + return f'mov {prefix}{dst}.16b, {prefix}{src}.16b\n' + + def clamp_min(self, reg, prefix): + max_reg = self.max_register() + return f'fmin {prefix}{reg}.4s, {max_reg}.4s, {prefix}{reg}.4s\n' + + def clamp_max(self, reg, prefix): + min_reg = self.min_register() + return f'fmax {prefix}{reg}.4s, {min_reg}.4s, {prefix}{reg}.4s\n' + + def store( + self, + M, + N, + ): + accumulators = self.acc_registers() + cm_registers = self.cm_registers() + nc_reg = self.nc_register() + nc_lo = self.register_map_byte(nc_reg) + N_COUNT = N // self.n_step() + asm_string = """ + # Check whether full or partial store. 
+ cmp {nc}, {n_step} + b.lo tail_{N_2}\n""".format(n_step=N, N_2=N // 2, nc=nc_reg) + for mr in range(0, M): + asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format( + ACC=accumulators[mr], + ACC_1=accumulators[M + mr], + c_reg=cm_registers[mr], + ) + for nr in range(2, N_COUNT, 2): + asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format( + ACC=accumulators[M * 2 + mr], + ACC_1=accumulators[M * 3 + mr], + c_reg=cm_registers[mr], + ) + for mr in range(0, M): + asm_string += 'sub {AM_PTR}, {AM_PTR}, {kc_register}\n'.format(AM_PTR=self.am_registers()[mr], kc_register=self.kc_register()) + CHECK = """ + sub {nc}, {nc}, {n_step} + b.ne outer_loop + b return""".format(n_step=N, nc=nc_reg) + asm_string += CHECK + N = N // 2 + if N * 2 > self.n_step(): + if N == 8: + asm_string += """ +\ntail_8: + tbz {nc_lo}, 3, tail_4\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format( + ACC=accumulators[mr], + ACC_1=accumulators[mr + M], + c_reg=cm_registers[mr], + ) + for mr in range(0, M): + asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format( + ACC0=accumulators[mr], ACC1=accumulators[mr + 2 * M] + ) + asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format( + ACC0=accumulators[mr + M], ACC1=accumulators[mr + 3 * M] + ) + asm_string += """ +\ntail_4: + tbz {nc_lo}, 2, tail_2\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'str q{ACC}, [{c_reg}], 16\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format( + ACC0=accumulators[mr], ACC1=accumulators[mr + M] + ) + asm_string += """ +\ntail_2: + tbz {nc_lo}, 1, tail_1\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'str d{ACC}, [{c_reg}], 8\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'dup d{ACC}, v{ACC}.d[1]\n'.format(ACC=accumulators[mr]) + asm_string += """ +\ntail_1: + tbz {nc_lo}, 0, return\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'str s{ACC}, [{c_reg}]\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + + return asm_string diff --git a/gemm_compiler/x64_template.py b/gemm_compiler/x64_template.py new file mode 100644 index 000000000000..637cd565beb6 --- /dev/null +++ b/gemm_compiler/x64_template.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
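# Illustrative only (not part of the generator): the NeonFma tail stores above
# (and the Fma3 equivalents) peel the remaining output columns by testing
# individual bits of nc, storing 8, 4, 2 and finally 1 value and shifting the
# surviving lanes down after each partial store.
def tail_store_widths(nc_remainder):
  """Partial-store widths executed for a remainder smaller than one full tile."""
  return [width for width in (8, 4, 2, 1) if nc_remainder & width]

# tail_store_widths(13) == [8, 4, 1]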
+ +from gemm_compiler import base_architecture as base_architecture + +"""All non SIMD features for x64.""" + + +class X64(base_architecture.BaseArchitecture): + + def astride_register(self): + return 'r8' + + def kc_register(self): + return 'rdx' + + def k_register(self): + return 'r11' + + def cm_stride_register(self): + return 'r11' + + def am_registers(self): + return [self.a_ptr_register()] + [ + 'rax', + 'r15', + 'r14', + 'r12', + 'r10', + 'r13', + 'rbx', + 'rbp', + 'r8', + 'rdi', + ] + + def a_ptr_register(self): + return 'rsi' + + def c_ptr_register(self): + return 'rsi' + + def cm_registers(self): + return [self.c_ptr_register()] + [ + 'rax', + 'r15', + 'r14', + 'r12', + 'r10', + 'r13', + 'rbx', + 'rbp', + 'r8', + 'rdi', + ] + + def acc_registers(self): + return [ + 'mm7', + 'mm8', + 'mm9', + 'mm14', + 'mm15', + 'mm16', + 'mm17', + 'mm18', + 'mm19', + 'mm20', + 'mm21', + 'mm22', + 'mm23', + 'mm24', + 'mm25', + 'mm26', + 'mm27', + 'mm28', + 'mm29', + 'mm30', + 'mm12', + 'mm13', + ] + + def w_ptr_register(self): + return 'r9' + + def min_register(self): + return 'mm0' + + def max_register(self): + return 'mm1' + + def nc_register(self): + return 'rcx' + + def mr_register(self): + return 'rdi' + + def tmp_gp_registers(self): + return ['rdi', 'r11'] + + def register_map_byte(self, reg): + """Maps 64 bit register names to their low 8 bits.""" + map = { + 'rax': 'al', + 'rcx': 'cl', + 'rdx': 'dl', + 'rbx': 'bl', + 'rsi': 'sil', + 'rdi': 'dil', + 'rsp': 'spl', + 'rbp': 'bpl', + 'r8': 'r8b', + 'r9': 'r9b', + 'r10': 'r10b', + 'r11': 'r11b', + 'r12': 'r12b', + 'r13': 'r13b', + 'r14': 'r14b', + 'r15': 'r15b', + } + return map[reg] + + def register_map_dword(self, reg): + """Maps 64 bit register names to their low 32 bits.""" + map = { + 'rax': 'eax', + 'rcx': 'ecx', + 'rdx': 'edx', + 'rbx': 'ebx', + 'rsi': 'esi', + 'rdi': 'edi', + 'rsp': 'esp', + 'rbp': 'ebp', + 'r8': 'r8d', + 'r9': 'r9d', + 'r10': 'r10d', + 'r11': 'r11d', + 'r12': 'r12d', + 'r13': 'r13d', + 'r14': 'r14d', + 'r15': 'r15d', + } + return map[reg] + + def jump_to_label(self, label): + return f'jmp {label}' + + def function_name(self, M, N, isa): + return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_amd64_{isa}_broadcast\n' + + def header(self, M, N, prefix, isa): + HEADER = '#include "xnnpack/assembly.h"\n\n' + + HEADER += 'BEGIN_FUNCTION ' + self.function_name(M, N, isa) + HEADER += """ + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss {prefix}mm0, DWORD PTR [r13] + vbroadcastss {prefix}mm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64]\n""".format(M=M, N=N, prefix=prefix, isa=isa) + return HEADER + + def input_output_register_setup(self, M): + registers = self.am_registers() + a_stride = self.astride_register() + c_stride = self.cm_stride_register() + INPUT_OUTPUT_REGISTER_SETUP = """ + # Clamp a & c pointers if mr <= {M} + mov {aM}, {aM_1} + add {aM}, {A_STRIDE} + mov {cM}, {cM_1} + add {cM}, {C_STRIDE} + cmp {mr_reg}, {M} + cmovle {aM}, {aM_1} + cmovle {cM}, {cM_1}\n""" + INPUT_OUTPUT_REGISTER_PUSH = """ + mov [rsp - {a_rsp_offset}], {aM} + mov [rsp - {c_rsp_offset}], {cM}\n""" + ret = '' + if self.stack_size(M) != 0: + ret += """sub rsp, {stack_size}\n""".format( + stack_size=self.stack_size(M) + ) + # Write rsi & r10 if required to the stack. + if M > self.max_M_before_spilling(): + ret += ( + '# Write rsi (a pointer) to the stack as we need the register.\n' + ) + ret += 'mov [rsp - 128], rsi\n' + ret += ( + '# Write r10 (c pointer) to the stack as we need the register.\n' + ) + ret += 'mov [rsp - 136], r10\n' + for mr in range(1, M): + # cycle size of 2 if required + if M > self.max_M_before_spilling(): + a_pos = mr % 2 + c_pos = (mr % 2) + self.max_M_before_spilling() + a_pos_1 = (mr + 1) % 2 + c_pos_1 = ((mr + 1) % 2) + self.max_M_before_spilling() + else: + a_pos = mr + c_pos = mr + self.max_M_before_spilling() + a_pos_1 = a_pos - 1 + c_pos_1 = c_pos - 1 + a_rsp_offset = 144 + (mr - 1) * 16 + ret += INPUT_OUTPUT_REGISTER_SETUP.format( + M=mr, + aM=registers[a_pos], + aM_1=registers[a_pos_1], + cM=registers[c_pos], + cM_1=registers[c_pos_1], + A_STRIDE=a_stride, + C_STRIDE=c_stride, + mr_reg=self.mr_register(), + a_rsp_offset=a_rsp_offset, + c_rsp_offset=a_rsp_offset + 8, + ) + if M > self.max_M_before_spilling(): + ret += INPUT_OUTPUT_REGISTER_PUSH.format( + M=mr, + aM=registers[a_pos], + aM_1=registers[a_pos_1], + cM=registers[c_pos], + cM_1=registers[c_pos_1], + A_STRIDE=a_stride, + C_STRIDE=c_stride, + mr_reg=self.mr_register(), + a_rsp_offset=a_rsp_offset, + c_rsp_offset=a_rsp_offset + 8, + ) + + return ret + + def max_M_before_spilling(self): + return 5 + + def read_a_registers(self, M): + registers = self.am_registers() + if M <= self.max_M_before_spilling(): + return '' + ret = '# Read a pointers from stack into GP registers.\n' + POP_A = 'mov {aM}, [rsp - {a_rsp_offset}]\n' + for mr in range(0, M): + a_rsp_offset = 128 + mr * 16 + ret += POP_A.format(aM=registers[mr], a_rsp_offset=a_rsp_offset) + ret += '\n' + return ret + + def increment_ptr(self, ptr, step): + return f'add {ptr}, {step}\n' + + def initialize_k_register(self, reg): + return f'mov {reg}, 0\n' + + def cmp_k_and_jump_if_less(self, label): + kc_register = self.kc_register() + k_register = self.k_register() + return """ + add {k_register}, 4 + cmp {kc_register}, {k_register} + jne {label}\n""".format( + label=label, k_register=k_register, kc_register=kc_register + ) + + def load_from_stack(self, reg, offset): + """Load 8 bytes from the given offset from the stack pointer to reg.""" + return f'mov {reg}, [rsp - {offset}]\n' + + def epilogue(self, M, N, isa): + restore_stack = '\nreturn:\n' + if isa.stack_size(M) != 0: + restore_stack += 'add rsp, {stack_ptr_sub}\n'.format( + stack_ptr_sub=isa.stack_size(M) + ) + restore_stack += """ + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION {function_name}""".format( + M=M, N=N, function_name=isa.function_name(M, N, isa.isa()) + ) + return restore_stack + + def stack_size(self, M): + """Returns the required stack storage space.""" + return 0 + + def inner_loop(self, M, N): + if M > self.max_M_before_spilling(): + return self.inner_loop_spill_gp(M, N) + else: + return self.inner_loop_small_M_N(M, N) diff --git a/gen/aarch64_microkernels.bzl b/gen/aarch64_microkernels.bzl index 52db94584237..d03ba1a5d09e 100644 --- a/gen/aarch64_microkernels.bzl +++ b/gen/aarch64_microkernels.bzl @@ -120,6 +120,7 @@ NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = [ "src/f32-dwconv/f32-dwconv-9p4c-minmax-asm-aarch64-neonfma.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2-prfm.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2.S", + "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2-prfm.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4-prfm.S", @@ -131,15 +132,24 @@ NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = [ "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128.S", "src/f32-gemm/gen/f32-gemm-1x12-minmax-asm-aarch64-neonfma-cortex-a53.S", + "src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld128.S", "src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S", "src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75.S", "src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-ld64.S", + "src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-gemm/gen/f32-gemm-4x12-minmax-asm-aarch64-neonfma-cortex-a53.S", + "src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S", "src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75.S", + "src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-cortex-a75.S", "src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S", @@ -225,6 +235,14 @@ NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = [ "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld128.S", "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S", + 
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S", "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S", diff --git a/gen/amd64_microkernels.bzl b/gen/amd64_microkernels.bzl new file mode 100644 index 000000000000..4d06fff1aa0f --- /dev/null +++ b/gen/amd64_microkernels.bzl @@ -0,0 +1,68 @@ +""" +Microkernel filenames lists for amd64. + +Auto-generated file. Do not edit! + Generator: tools/update-microkernels.py +""" + +PROD_AMD64_ASM_MICROKERNEL_SRCS = [ +] + +NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [ + "src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S", + 
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S", +] + +AMD64_ASM_MICROKERNEL_SRCS = PROD_AMD64_ASM_MICROKERNEL_SRCS + NON_PROD_AMD64_ASM_MICROKERNEL_SRCS diff --git a/gen/microkernels.bzl b/gen/microkernels.bzl index 988614c49816..24c0255ad1f1 100644 --- a/gen/microkernels.bzl +++ b/gen/microkernels.bzl @@ -7,6 +7,7 @@ Auto-generated file. Do not edit! 
load("aarch32_microkernels.bzl", _AARCH32_ASM_MICROKERNEL_SRCS = "AARCH32_ASM_MICROKERNEL_SRCS", _NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS = "NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS", _PROD_AARCH32_ASM_MICROKERNEL_SRCS = "PROD_AARCH32_ASM_MICROKERNEL_SRCS") load("aarch64_microkernels.bzl", _AARCH64_ASM_MICROKERNEL_SRCS = "AARCH64_ASM_MICROKERNEL_SRCS", _NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = "NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS", _PROD_AARCH64_ASM_MICROKERNEL_SRCS = "PROD_AARCH64_ASM_MICROKERNEL_SRCS") +load("amd64_microkernels.bzl", _AMD64_ASM_MICROKERNEL_SRCS = "AMD64_ASM_MICROKERNEL_SRCS", _NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = "NON_PROD_AMD64_ASM_MICROKERNEL_SRCS", _PROD_AMD64_ASM_MICROKERNEL_SRCS = "PROD_AMD64_ASM_MICROKERNEL_SRCS") load("armsimd32_microkernels.bzl", _ALL_ARMSIMD32_MICROKERNEL_SRCS = "ALL_ARMSIMD32_MICROKERNEL_SRCS", _NON_PROD_ARMSIMD32_MICROKERNEL_SRCS = "NON_PROD_ARMSIMD32_MICROKERNEL_SRCS", _PROD_ARMSIMD32_MICROKERNEL_SRCS = "PROD_ARMSIMD32_MICROKERNEL_SRCS") load("avx256skx_microkernels.bzl", _ALL_AVX256SKX_MICROKERNEL_SRCS = "ALL_AVX256SKX_MICROKERNEL_SRCS", _NON_PROD_AVX256SKX_MICROKERNEL_SRCS = "NON_PROD_AVX256SKX_MICROKERNEL_SRCS", _PROD_AVX256SKX_MICROKERNEL_SRCS = "PROD_AVX256SKX_MICROKERNEL_SRCS") load("avx256vnni_microkernels.bzl", _ALL_AVX256VNNI_MICROKERNEL_SRCS = "ALL_AVX256VNNI_MICROKERNEL_SRCS", _NON_PROD_AVX256VNNI_MICROKERNEL_SRCS = "NON_PROD_AVX256VNNI_MICROKERNEL_SRCS", _PROD_AVX256VNNI_MICROKERNEL_SRCS = "PROD_AVX256VNNI_MICROKERNEL_SRCS") @@ -103,8 +104,10 @@ ALL_SSSE3_MICROKERNEL_SRCS = _ALL_SSSE3_MICROKERNEL_SRCS ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS = _ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS ALL_WASMSIMD_MICROKERNEL_SRCS = _ALL_WASMSIMD_MICROKERNEL_SRCS ALL_WASM_MICROKERNEL_SRCS = _ALL_WASM_MICROKERNEL_SRCS +AMD64_ASM_MICROKERNEL_SRCS = _AMD64_ASM_MICROKERNEL_SRCS NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS = _NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = _NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS +NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = _NON_PROD_AMD64_ASM_MICROKERNEL_SRCS NON_PROD_ARMSIMD32_MICROKERNEL_SRCS = _NON_PROD_ARMSIMD32_MICROKERNEL_SRCS NON_PROD_AVX256SKX_MICROKERNEL_SRCS = _NON_PROD_AVX256SKX_MICROKERNEL_SRCS NON_PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS = _NON_PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS @@ -155,6 +158,7 @@ NON_PROD_WASMSIMD_MICROKERNEL_SRCS = _NON_PROD_WASMSIMD_MICROKERNEL_SRCS NON_PROD_WASM_MICROKERNEL_SRCS = _NON_PROD_WASM_MICROKERNEL_SRCS PROD_AARCH32_ASM_MICROKERNEL_SRCS = _PROD_AARCH32_ASM_MICROKERNEL_SRCS PROD_AARCH64_ASM_MICROKERNEL_SRCS = _PROD_AARCH64_ASM_MICROKERNEL_SRCS +PROD_AMD64_ASM_MICROKERNEL_SRCS = _PROD_AMD64_ASM_MICROKERNEL_SRCS PROD_ARMSIMD32_MICROKERNEL_SRCS = _PROD_ARMSIMD32_MICROKERNEL_SRCS PROD_AVX256SKX_MICROKERNEL_SRCS = _PROD_AVX256SKX_MICROKERNEL_SRCS PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS = _PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS @@ -306,6 +310,7 @@ NON_PROD_C_SRCS_FOR_ARCH = { PROD_ASM_SRCS_FOR_ARCH = { "aarch32": PROD_AARCH32_ASM_MICROKERNEL_SRCS, "aarch64": PROD_AARCH64_ASM_MICROKERNEL_SRCS, + "amd64": PROD_AMD64_ASM_MICROKERNEL_SRCS, "wasm32": PROD_WASM32_ASM_MICROKERNEL_SRCS, "wasmrelaxedsimd32": PROD_WASMRELAXEDSIMD32_ASM_MICROKERNEL_SRCS, "wasmsimd32": PROD_WASMSIMD32_ASM_MICROKERNEL_SRCS, @@ -314,6 +319,7 @@ PROD_ASM_SRCS_FOR_ARCH = { NON_PROD_ASM_SRCS_FOR_ARCH = { "aarch32": NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS, "aarch64": NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS, + "amd64": NON_PROD_AMD64_ASM_MICROKERNEL_SRCS, "wasm32": NON_PROD_WASM32_ASM_MICROKERNEL_SRCS, "wasmrelaxedsimd32": 
NON_PROD_WASMRELAXEDSIMD32_ASM_MICROKERNEL_SRCS, "wasmsimd32": NON_PROD_WASMSIMD32_ASM_MICROKERNEL_SRCS, diff --git a/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..b2da35107d20 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,298 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with the biases. 
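      # The 16 bias values are loaded once into zmm7 below and copied to the
      # other nine row accumulators.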
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + vmovups [r8], zmm20 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..efd156a9ab43 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,361 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm21, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + vmovaps zmm22, zmm21 + vmovaps zmm23, zmm21 + vmovaps zmm24, zmm21 + vmovaps zmm25, zmm21 + vmovaps zmm26, zmm21 + vmovaps zmm27, zmm21 + vmovaps zmm28, zmm21 + vmovaps zmm29, zmm21 + vmovaps zmm30, zmm21 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm27, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm28, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm29, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + vfmadd231ps zmm30, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. 
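Note: rcx tracks how many output columns remain. When at least 32 columns are left, the store below writes two full ZMM vectors per row; otherwise the tail path builds AVX-512 write masks from the remaining count. A small C model of that mask math (hypothetical helper, not part of the patch; only valid for nc < 32, which is the only case the tail handles):

#include <stdint.h>

// Illustrative sketch of the tail-mask computation used below
// (mov r11d, -1 ; sal r11d, cl ; not r11d ; kmovw).  nc is the number of
// remaining columns; the low nc bits select which lanes are stored.
static void make_tail_masks_32(uint32_t nc, uint16_t* k1, uint16_t* k2) {
  const uint32_t mask = ~(~UINT32_C(0) << nc);  // (1 << nc) - 1
  *k1 = (uint16_t) mask;           // lanes 0..15  (kmovw k1, r11d)
  *k2 = (uint16_t) (mask >> 16);   // lanes 16..31 (shr r11d, 16 ; kmovw k2)
}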
+ cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm21 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm22 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm23 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm24 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm25 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm26 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm27 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm28 + vmovups [rbp], zmm19 + vmovups [rbp + 64], zmm29 + vmovups [r8], zmm20 + vmovups [r8 + 64], zmm30 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm30 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..c8c4c288e71e --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,321 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + vmovaps zmm21, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rdi + r11] + vfmadd231ps zmm21, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + vmovups [r8], zmm20 + vmovups [rdi], zmm21 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + add rdi, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [rdi]{k1}, zmm21 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..8e090a0b1822 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,390 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. 
+ mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm22, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + vmovaps zmm21, zmm7 + vmovaps zmm23, zmm22 + vmovaps zmm24, zmm22 + vmovaps zmm25, zmm22 + vmovaps zmm26, zmm22 + vmovaps zmm27, zmm22 + vmovaps zmm28, zmm22 + vmovaps zmm29, zmm22 + vmovaps zmm30, zmm22 + vmovaps zmm12, zmm22 + vmovaps zmm13, zmm22 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm27, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm28, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm29, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm30, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rdi + r11] + vfmadd231ps zmm21, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + + # Pop output pointers from the stack. 
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm22 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm23 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm24 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm25 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm26 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm27 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm28 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm29 + vmovups [rbp], zmm19 + vmovups [rbp + 64], zmm30 + vmovups [r8], zmm20 + vmovups [r8 + 64], zmm12 + vmovups [rdi], zmm21 + vmovups [rdi + 64], zmm13 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + add rdi, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm30 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [rdi]{k1}, zmm21 + vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm13 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..c6827b2b6698 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,96 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. 
+ ldp q11, q12, [x5, 0] + ldp q13, q14, [x5, 32] + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v8.4s, v2.s[0] + fmla v13.4s, v9.4s, v2.s[0] + fmla v14.4s, v10.4s, v2.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q11, q12, [x6], 32 + stp q13, q14, [x6], 32 + sub x3, x3, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q12, [x6], 32 + mov v11.16b, v13.16b + mov v12.16b, v14.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + mov v11.16b, v12.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + dup d11, v11.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..de9feca9f79a --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,78 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vmaxps zmm7, zmm0, zmm7 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + add r10, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + +return: + + # Restore the callee saved registers. 
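Note: both 1x16 variants above implement the same scalar math: each inner-loop step broadcasts one element of A and accumulates it against a 16-wide block of packed weights, and the result is clamped to [min, max] before the (possibly masked or partial) store. A plain C model of that loop, with illustrative names only (kc is in bytes, as in the assembly):

#include <stddef.h>

// Illustrative C model of the 1x16 inner loop and clamping above
// (hypothetical helper, not part of the patch).
static void gemm_1x16_model(size_t kc, const float* a, const float* w,
                            float acc[16], float vmin, float vmax) {
  for (size_t k = 0; k < kc; k += 4) {          // add r11, 4 ; cmp rdx, r11
    const float a0 = *a++;                      // vbroadcastss zmm2
    for (size_t n = 0; n < 16; n++) {
      acc[n] += a0 * w[n];                      // vfmadd231ps zmm7, zmm2, zmm10
    }
    w += 16;                                    // add r9, 64
  }
  for (size_t n = 0; n < 16; n++) {
    acc[n] = acc[n] > vmax ? vmax : acc[n];     // vminps zmm7, zmm1, zmm7
    acc[n] = acc[n] < vmin ? vmin : acc[n];     // vmaxps zmm7, zmm0, zmm7
  }
}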
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..8bf1ff26b55e --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,87 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm8, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm8 + add r10, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..f0e2e17c3d89 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,106 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm14, [r9 + 192] + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm8, zmm2, zmm11 + vfmadd231ps zmm9, zmm2, zmm12 + vfmadd231ps zmm14, zmm2, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm8 + vmovups [r10 + 128], zmm9 + vmovups [r10 + 192], zmm14 + add r10, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm9 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm14 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..d6aec57123e2 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,80 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q12, [x5, 0] + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v8.4s, v2.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q11, q12, [x6], 32 + sub x3, x3, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + mov v11.16b, v12.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + dup d11, v11.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..dd3cb73756ed --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,131 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q13, [x5, 0] + ldp q15, q17, [x5, 32] + mov v12.16b, v11.16b + mov v14.16b, v13.16b + mov v16.16b, v15.16b + mov v18.16b, v17.16b + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v8.4s, v2.s[0] + fmla v14.4s, v8.4s, v3.s[0] + fmla v15.4s, v9.4s, v2.s[0] + fmla v16.4s, v9.4s, v3.s[0] + fmla v17.4s, v10.4s, v2.s[0] + fmla v18.4s, v10.4s, v3.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q11, q13, [x6], 32 + stp q15, q17, [x6], 32 + stp q12, q14, [x13], 32 + stp q16, q18, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q13, [x6], 32 + stp q12, q14, [x13], 32 + mov v11.16b, v15.16b + mov v13.16b, v17.16b + mov v12.16b, v16.16b + mov v14.16b, v18.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + mov v11.16b, v13.16b + mov v12.16b, v14.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
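Note: the aarch64 kernels above handle partial columns differently from the AVX-512 ones: the remaining count is decomposed bit by bit (tbz on bits 3..0 of x1), storing 8, 4, 2 and finally 1 float, and sliding the still-unstored accumulator registers down after each step. A C model of that decomposition (hypothetical helper, not part of the patch):

#include <stddef.h>
#include <string.h>

// Illustrative sketch of the tail_8/tail_4/tail_2/tail_1 chain above.
// nc is the number of remaining columns (< 16 for the x16 kernels).
static void store_tail(float* c, const float acc[16], size_t nc) {
  float tmp[16];
  memcpy(tmp, acc, sizeof(tmp));
  size_t pos = 0;
  for (size_t chunk = 8; chunk >= 1; chunk /= 2) {
    if (nc & chunk) {                            // tbz x1, log2(chunk), next
      memcpy(c, &tmp[pos], chunk * sizeof(float));
      c += chunk;
      // The assembly instead moves the upper registers down
      // (mov v11.16b, v13.16b, ...) so it always stores from v11/v12.
      pos += chunk;
    }
  }
}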
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..36236780321b --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,95 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + add r10, 64 + add r13, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..f80ba7448b11 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,110 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
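Note: the accumulator setup that follows reads the bias vectors once from the head of the packed-weights stream (r9) and copies them into every row's accumulators, which is why only the first row is loaded from memory and the other rows are register copies. A C sketch of the 2x32 case (hypothetical helper, not part of the patch; the leading-bias layout of the packed weights is assumed from XNNPACK's packing convention):

#include <stddef.h>

// Illustrative sketch of the accumulator initialization below (hypothetical
// helper, not part of the patch).  Both rows start from the same 32 biases
// at the front of the packed-weights block.
static const float* init_acc_2x32(const float* w, float acc0[32], float acc1[32]) {
  for (size_t n = 0; n < 32; n++) {
    acc0[n] = w[n];   // vmovaps zmm7, [r9] ; vmovaps zmm9, [r9 + 64]
    acc1[n] = w[n];   // vmovaps zmm8, zmm7 ; vmovaps zmm14, zmm9
  }
  return w + 32;      // add r9, 128
}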
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm9, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm14, zmm9 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm9, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm14, zmm3, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm9 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm14 + add r10, 128 + add r13, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..105cf33322eb --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,141 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm9, [r9 + 64] + vmovaps zmm15, [r9 + 128] + vmovaps zmm17, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm14, zmm9 + vmovaps zmm16, zmm15 + vmovaps zmm18, zmm17 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm9, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm12 + vfmadd231ps zmm17, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm14, zmm3, zmm11 + vfmadd231ps zmm16, zmm3, zmm12 + vfmadd231ps zmm18, zmm3, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm9 + vmovups [r10 + 128], zmm15 + vmovups [r10 + 192], zmm17 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm14 + vmovups [r13 + 128], zmm16 + vmovups [r13 + 192], zmm18 + add r10, 256 + add r13, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm15 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm16 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm18 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..bba3de14cf77 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,103 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q13, [x5, 0] + mov v12.16b, v11.16b + mov v14.16b, v13.16b + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v8.4s, v2.s[0] + fmla v14.4s, v8.4s, v3.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + + # Check whether full or partial store. 
+ cmp x1, 8 + b.lo tail_4 + stp q11, q13, [x6], 32 + stp q12, q14, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + mov v11.16b, v13.16b + mov v12.16b, v14.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..9b2af6912f4f --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,163 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q14, [x5, 0] + ldp q17, q20, [x5, 32] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v15.16b, v14.16b + mov v16.16b, v14.16b + mov v18.16b, v17.16b + mov v19.16b, v17.16b + mov v21.16b, v20.16b + mov v22.16b, v20.16b + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v8.4s, v2.s[0] + fmla v15.4s, v8.4s, v3.s[0] + fmla v16.4s, v8.4s, v4.s[0] + fmla v17.4s, v9.4s, v2.s[0] + fmla v18.4s, v9.4s, v3.s[0] + fmla v19.4s, v9.4s, v4.s[0] + fmla v20.4s, v10.4s, v2.s[0] + fmla v21.4s, v10.4s, v3.s[0] + fmla v22.4s, v10.4s, v4.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + + # Check whether full or partial store. 
+ cmp x1, 16 + b.lo tail_8 + stp q11, q14, [x6], 32 + stp q17, q20, [x6], 32 + stp q12, q15, [x13], 32 + stp q18, q21, [x13], 32 + stp q13, q16, [x14], 32 + stp q19, q22, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q14, [x6], 32 + stp q12, q15, [x13], 32 + stp q13, q16, [x14], 32 + mov v11.16b, v17.16b + mov v14.16b, v20.16b + mov v12.16b, v18.16b + mov v15.16b, v21.16b + mov v13.16b, v19.16b + mov v16.16b, v22.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + mov v11.16b, v14.16b + mov v12.16b, v15.16b + mov v13.16b, v16.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..53ca4d7700ee --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,112 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + + # Check whether full or partial store. 
+ cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + vmovups [rbx], zmm9 + add r10, 64 + add r13, 64 + add rbx, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..5510c79e6df7 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,133 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm14, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm15, zmm14 + vmovaps zmm16, zmm14 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm15, zmm3, zmm11 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm16, zmm4, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm14 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm15 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm16 + add r10, 128 + add r13, 128 + add rbx, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..bd1641c2879a --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,176 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm14, [r9 + 64] + vmovaps zmm17, [r9 + 128] + vmovaps zmm20, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm15, zmm14 + vmovaps zmm16, zmm14 + vmovaps zmm18, zmm17 + vmovaps zmm19, zmm17 + vmovaps zmm21, zmm20 + vmovaps zmm22, zmm20 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm12 + vfmadd231ps zmm20, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm15, zmm3, zmm11 + vfmadd231ps zmm18, zmm3, zmm12 + vfmadd231ps zmm21, zmm3, zmm13 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm16, zmm4, zmm11 + vfmadd231ps zmm19, zmm4, zmm12 + vfmadd231ps zmm22, zmm4, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + + # Check whether full or partial store. 
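+  # Same masked-store idea as in the 16-wide kernels, but for a 64-column tile
+  # the remainder mask is built in the full 64-bit r11 and then split into four
+  # 16-bit pieces (kmovw k1 / shr 16 / kmovw k2 / ...), one opmask per 16-float
+  # block of the row.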
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm14 + vmovups [r10 + 128], zmm17 + vmovups [r10 + 192], zmm20 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm15 + vmovups [r13 + 128], zmm18 + vmovups [r13 + 192], zmm21 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm16 + vmovups [rbx + 128], zmm19 + vmovups [rbx + 192], zmm22 + add r10, 256 + add r13, 256 + add rbx, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm17 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm20 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm18 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm21 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm19 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm22 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..26bebc208e7b --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,123 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q14, [x5, 0] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v15.16b, v14.16b + mov v16.16b, v14.16b + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v8.4s, v2.s[0] + fmla v15.4s, v8.4s, v3.s[0] + fmla v16.4s, v8.4s, v4.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + + # Check whether full or partial store. 
+ cmp x1, 8 + b.lo tail_4 + stp q11, q14, [x6], 32 + stp q12, q15, [x13], 32 + stp q13, q16, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + mov v11.16b, v14.16b + mov v12.16b, v15.16b + mov v13.16b, v16.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..fb40896a6e4a --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,197 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q15, [x5, 0] + ldp q19, q23, [x5, 32] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v16.16b, v15.16b + mov v17.16b, v15.16b + mov v18.16b, v15.16b + mov v20.16b, v19.16b + mov v21.16b, v19.16b + mov v22.16b, v19.16b + mov v24.16b, v23.16b + mov v25.16b, v23.16b + mov v26.16b, v23.16b + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v8.4s, v2.s[0] + fmla v16.4s, v8.4s, v3.s[0] + fmla v17.4s, v8.4s, v4.s[0] + fmla v18.4s, v8.4s, v5.s[0] + fmla v19.4s, v9.4s, v2.s[0] + fmla v20.4s, v9.4s, v3.s[0] + fmla v21.4s, v9.4s, v4.s[0] + fmla v22.4s, v9.4s, v5.s[0] + fmla v23.4s, v10.4s, v2.s[0] + fmla v24.4s, v10.4s, v3.s[0] + fmla v25.4s, v10.4s, v4.s[0] + fmla v26.4s, v10.4s, v5.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. 
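+  # v0/v1 were filled by the ld2r above from the params struct: the first float
+  # is the output min (v0), the second the output max (v1). Each accumulator is
+  # clamped with fmin against v1 (upper bound) and fmax against v0 (lower
+  # bound) before being stored.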
+ fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmin v24.4s, v1.4s, v24.4s + fmin v25.4s, v1.4s, v25.4s + fmin v26.4s, v1.4s, v26.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + fmax v24.4s, v0.4s, v24.4s + fmax v25.4s, v0.4s, v25.4s + fmax v26.4s, v0.4s, v26.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q11, q15, [x6], 32 + stp q19, q23, [x6], 32 + stp q12, q16, [x13], 32 + stp q20, q24, [x13], 32 + stp q13, q17, [x14], 32 + stp q21, q25, [x14], 32 + stp q14, q18, [x15], 32 + stp q22, q26, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q15, [x6], 32 + stp q12, q16, [x13], 32 + stp q13, q17, [x14], 32 + stp q14, q18, [x15], 32 + mov v11.16b, v19.16b + mov v15.16b, v23.16b + mov v12.16b, v20.16b + mov v16.16b, v24.16b + mov v13.16b, v21.16b + mov v17.16b, v25.16b + mov v14.16b, v22.16b + mov v18.16b, v26.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + mov v11.16b, v15.16b + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..d0938f5ae743 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,129 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
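+  # (cm_stride is the byte stride between consecutive output rows. The
+  # cmp/cmovle blocks that follow clamp the per-row a & c pointers: when mr is
+  # smaller than the kernel height, the pointers for the unused rows alias the
+  # previous row, so those rows recompute valid data instead of touching
+  # out-of-range memory.)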
+ mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + vmovups [rbx], zmm9 + vmovups [rbp], zmm14 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..75764d9495ad --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,156 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
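+  # The packed-weights pointer (r9) points at the per-tile bias first: each
+  # 16-float block of the bias is loaded once and copied to every row's
+  # accumulator, then r9 is advanced past the bias so the inner loop reads the
+  # packed B panels from the same pointer.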
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm15, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm16, zmm15 + vmovaps zmm17, zmm15 + vmovaps zmm18, zmm15 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm16, zmm3, zmm11 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm17, zmm4, zmm11 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm18, zmm5, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm15 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm16 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm17 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm18 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..4492f3fabef6 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,211 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm15, [r9 + 64] + vmovaps zmm19, [r9 + 128] + vmovaps zmm23, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm16, zmm15 + vmovaps zmm17, zmm15 + vmovaps zmm18, zmm15 + vmovaps zmm20, zmm19 + vmovaps zmm21, zmm19 + vmovaps zmm22, zmm19 + vmovaps zmm24, zmm23 + vmovaps zmm25, zmm23 + vmovaps zmm26, zmm23 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm11 + vfmadd231ps zmm19, zmm2, zmm12 + vfmadd231ps zmm23, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm16, zmm3, zmm11 + vfmadd231ps zmm20, zmm3, zmm12 + vfmadd231ps zmm24, zmm3, zmm13 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm17, zmm4, zmm11 + vfmadd231ps zmm21, zmm4, zmm12 + vfmadd231ps zmm25, zmm4, zmm13 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm18, zmm5, zmm11 + vfmadd231ps zmm22, zmm5, zmm12 + vfmadd231ps zmm26, zmm5, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + + # Check whether full or partial store. 
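+  # Full-tile path: every row stores 64 floats (256 bytes), the c pointers
+  # advance by 256, and control returns to outer_loop for the next column tile.
+  # The a pointers need no rewinding here because the k index (r11) is
+  # re-zeroed at the top of outer_loop and only used as an offset.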
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm15 + vmovups [r10 + 128], zmm19 + vmovups [r10 + 192], zmm23 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm16 + vmovups [r13 + 128], zmm20 + vmovups [r13 + 192], zmm24 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm17 + vmovups [rbx + 128], zmm21 + vmovups [rbx + 192], zmm25 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm18 + vmovups [rbp + 128], zmm22 + vmovups [rbp + 192], zmm26 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm19 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm23 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm20 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm24 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm21 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm25 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm26 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..70f9d83b2f50 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,145 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. + ldp q11, q15, [x5, 0] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v16.16b, v15.16b + mov v17.16b, v15.16b + mov v18.16b, v15.16b + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v8.4s, v2.s[0] + fmla v16.4s, v8.4s, v3.s[0] + fmla v17.4s, v8.4s, v4.s[0] + fmla v18.4s, v8.4s, v5.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. 
+ fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q11, q15, [x6], 32 + stp q12, q16, [x13], 32 + stp q13, q17, [x14], 32 + stp q14, q18, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + mov v11.16b, v15.16b + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..af9cccf406ec --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,229 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x12, x11, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + add x19, x15, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + csel x12, x11, x12, LS + csel x19, x15, x19, LS + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. 
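+  # The bias for the 16-column tile is loaded pairwise from the packed weights
+  # (x5) and copied to all five row accumulators. The inner loop then consumes
+  # one f32 of A per row (ld32) and one 16-wide B panel per iteration; as a
+  # rough C sketch: c_acc[row][0..15] += a[row][k] * b[k][0..15].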
+ ldp q11, q16, [x5, 0] + ldp q21, q26, [x5, 32] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v15.16b, v11.16b + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v22.16b, v21.16b + mov v23.16b, v21.16b + mov v24.16b, v21.16b + mov v25.16b, v21.16b + mov v27.16b, v26.16b + mov v28.16b, v26.16b + mov v29.16b, v26.16b + mov v30.16b, v26.16b + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldr s6, [x12], 4 + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v7.4s, v6.s[0] + fmla v16.4s, v8.4s, v2.s[0] + fmla v17.4s, v8.4s, v3.s[0] + subs x20, x20, 4 + fmla v18.4s, v8.4s, v4.s[0] + fmla v19.4s, v8.4s, v5.s[0] + fmla v20.4s, v8.4s, v6.s[0] + fmla v21.4s, v9.4s, v2.s[0] + fmla v22.4s, v9.4s, v3.s[0] + fmla v23.4s, v9.4s, v4.s[0] + fmla v24.4s, v9.4s, v5.s[0] + fmla v25.4s, v9.4s, v6.s[0] + fmla v26.4s, v10.4s, v2.s[0] + fmla v27.4s, v10.4s, v3.s[0] + fmla v28.4s, v10.4s, v4.s[0] + fmla v29.4s, v10.4s, v5.s[0] + fmla v30.4s, v10.4s, v6.s[0] + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmin v24.4s, v1.4s, v24.4s + fmin v25.4s, v1.4s, v25.4s + fmin v26.4s, v1.4s, v26.4s + fmin v27.4s, v1.4s, v27.4s + fmin v28.4s, v1.4s, v28.4s + fmin v29.4s, v1.4s, v29.4s + fmin v30.4s, v1.4s, v30.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + fmax v24.4s, v0.4s, v24.4s + fmax v25.4s, v0.4s, v25.4s + fmax v26.4s, v0.4s, v26.4s + fmax v27.4s, v0.4s, v27.4s + fmax v28.4s, v0.4s, v28.4s + fmax v29.4s, v0.4s, v29.4s + fmax v30.4s, v0.4s, v30.4s + + # Check whether full or partial store. 
+ cmp x1, 16 + b.lo tail_8 + stp q11, q16, [x6], 32 + stp q21, q26, [x6], 32 + stp q12, q17, [x13], 32 + stp q22, q27, [x13], 32 + stp q13, q18, [x14], 32 + stp q23, q28, [x14], 32 + stp q14, q19, [x15], 32 + stp q24, q29, [x15], 32 + stp q15, q20, [x19], 32 + stp q25, q30, [x19], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + sub x12, x12, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q16, [x6], 32 + stp q12, q17, [x13], 32 + stp q13, q18, [x14], 32 + stp q14, q19, [x15], 32 + stp q15, q20, [x19], 32 + mov v11.16b, v21.16b + mov v16.16b, v26.16b + mov v12.16b, v22.16b + mov v17.16b, v27.16b + mov v13.16b, v23.16b + mov v18.16b, v28.16b + mov v14.16b, v24.16b + mov v19.16b, v29.16b + mov v15.16b, v25.16b + mov v20.16b, v30.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + str q15, [x19], 16 + mov v11.16b, v16.16b + mov v12.16b, v17.16b + mov v13.16b, v18.16b + mov v14.16b, v19.16b + mov v15.16b, v20.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + str d15, [x19], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + str s15, [x19] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..ea0997bc8250 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,146 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vbroadcastss zmm6, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + vmovups [rbx], zmm9 + vmovups [rbp], zmm14 + vmovups [r8], zmm15 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..7d76f7794f58 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,179 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm16, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm17, zmm16 + vmovaps zmm18, zmm16 + vmovaps zmm19, zmm16 + vmovaps zmm20, zmm16 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm17, zmm3, zmm11 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm18, zmm4, zmm11 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm19, zmm5, zmm11 + vbroadcastss zmm6, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm10 + vfmadd231ps zmm20, zmm6, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm16 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm17 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm18 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm19 + vmovups [r8], zmm15 + vmovups [r8 + 64], zmm20 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm20 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..8adfd1c1ca58 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,246 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm16, [r9 + 64] + vmovaps zmm21, [r9 + 128] + vmovaps zmm26, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm17, zmm16 + vmovaps zmm18, zmm16 + vmovaps zmm19, zmm16 + vmovaps zmm20, zmm16 + vmovaps zmm22, zmm21 + vmovaps zmm23, zmm21 + vmovaps zmm24, zmm21 + vmovaps zmm25, zmm21 + vmovaps zmm27, zmm26 + vmovaps zmm28, zmm26 + vmovaps zmm29, zmm26 + vmovaps zmm30, zmm26 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm11 + vfmadd231ps zmm21, zmm2, zmm12 + vfmadd231ps zmm26, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm17, zmm3, zmm11 + vfmadd231ps zmm22, zmm3, zmm12 + vfmadd231ps zmm27, zmm3, zmm13 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm18, zmm4, zmm11 + vfmadd231ps zmm23, zmm4, zmm12 + vfmadd231ps zmm28, zmm4, zmm13 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm19, zmm5, zmm11 + vfmadd231ps zmm24, zmm5, zmm12 + vfmadd231ps zmm29, zmm5, zmm13 + vbroadcastss zmm6, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm10 + vfmadd231ps zmm20, zmm6, zmm11 + vfmadd231ps zmm25, zmm6, zmm12 + vfmadd231ps zmm30, zmm6, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + + # Check whether full or partial store. 
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm16 + vmovups [r10 + 128], zmm21 + vmovups [r10 + 192], zmm26 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm17 + vmovups [r13 + 128], zmm22 + vmovups [r13 + 192], zmm27 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm18 + vmovups [rbx + 128], zmm23 + vmovups [rbx + 192], zmm28 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm19 + vmovups [rbp + 128], zmm24 + vmovups [rbp + 192], zmm29 + vmovups [r8], zmm15 + vmovups [r8 + 64], zmm20 + vmovups [r8 + 128], zmm25 + vmovups [r8 + 192], zmm30 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + add r8, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm21 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm26 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm27 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm23 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm28 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm24 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm29 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r8 + 128]{k3}, zmm25 + vmovups ZMMWORD PTR [r8 + 192]{k4}, zmm30 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..1205fdbff755 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,165 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x12, x11, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + add x19, x15, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + csel x12, x11, x12, LS + csel x19, x15, x19, LS + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with the biases. 
+ ldp q11, q16, [x5, 0] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v15.16b, v11.16b + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldr s6, [x12], 4 + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v7.4s, v6.s[0] + fmla v16.4s, v8.4s, v2.s[0] + fmla v17.4s, v8.4s, v3.s[0] + fmla v18.4s, v8.4s, v4.s[0] + fmla v19.4s, v8.4s, v5.s[0] + fmla v20.4s, v8.4s, v6.s[0] + subs x20, x20, 4 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q11, q16, [x6], 32 + stp q12, q17, [x13], 32 + stp q13, q18, [x14], 32 + stp q14, q19, [x15], 32 + stp q15, q20, [x19], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + sub x12, x12, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + str q15, [x19], 16 + mov v11.16b, v16.16b + mov v12.16b, v17.16b + mov v13.16b, v18.16b + mov v14.16b, v19.16b + mov v15.16b, v20.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + str d15, [x19], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + str s15, [x19] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..d44a42ccb3ad --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,206 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
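+  # (The 6-row kernels below run out of spare GP registers, so the per-row a
+  # and c pointers are parked in scratch stack slots at [rsp - 128] through
+  # [rsp - 216] and reloaded from there at the top of every outer_loop
+  # iteration and again before the stores.)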
+ mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..9a8a091a14c9 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,245 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with the biases. 
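+ # The packed weights at r9 begin with two 64-byte bias vectors covering the
+ # 32 output channels; each bias vector is replicated into the accumulators
+ # of all six rows before the K loop starts.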
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm17, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm18, zmm17 + vmovaps zmm19, zmm17 + vmovaps zmm20, zmm17 + vmovaps zmm21, zmm17 + vmovaps zmm22, zmm17 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm17 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm18 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm19 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm20 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm21 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm22 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm22 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..361c5d1ab3a9 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,229 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
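+ # Each accumulator is clamped against zmm1 with vminps and against zmm0
+ # with vmaxps, applying the output max and min broadcast from params in the
+ # prologue, before any results are stored.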
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..9354a90d8af1 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,274 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
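+ # The per-row a and c pointers are spilled to slots below the stack pointer
+ # (rsp-128 and down); the kernel is a leaf function, so these slots serve as
+ # scratch space for the whole outer loop without adjusting rsp.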
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm18, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm19, zmm18 + vmovaps zmm20, zmm18 + vmovaps zmm21, zmm18 + vmovaps zmm22, zmm18 + vmovaps zmm23, zmm18 + vmovaps zmm24, zmm18 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm18 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm19 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm20 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm21 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm22 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm23 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm24 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm24 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..7233c24350f2 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,252 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
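+ # cm_stride and a_stride (r8) feed the per-row pointer clamping that
+ # follows: each row's a and c pointer is the previous row's plus the stride,
+ # and cmovle reuses the previous row's pointers whenever mr (rdi) is smaller
+ # than the row index, so out-of-range rows alias row mr-1.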
+ mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. 
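+ # rcx holds the remaining output channels (nc); a full 16-wide tile is
+ # stored with plain vmovups, otherwise the tail path builds the store mask
+ # (1 << nc) - 1 with sal/not and applies it through k1.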
+ cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..e4633f2fcfe1 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,303 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm19, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm20, zmm19 + vmovaps zmm21, zmm19 + vmovaps zmm22, zmm19 + vmovaps zmm23, zmm19 + vmovaps zmm24, zmm19 + vmovaps zmm25, zmm19 + vmovaps zmm26, zmm19 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm19 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm20 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm21 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm22 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm23 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm24 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm25 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm26 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm26 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..c239863a9c7f --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,275 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. 
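+ # nc arrives in rsi and the a pointer in rcx; they are swapped below because
+ # the variable shift that builds the tail mask can only take its count
+ # from cl.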
+ mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..c65ab67655fd --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,332 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
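+ # The nine-row kernel repeats the pointer-clamping pattern up to mr <= 8,
+ # spilling nine a/c pointer pairs at rsp-128 down through rsp-264.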
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm20, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm21, zmm20 + vmovaps zmm22, zmm20 + vmovaps zmm23, zmm20 + vmovaps zmm24, zmm20 + vmovaps zmm25, zmm20 + vmovaps zmm26, zmm20 + vmovaps zmm27, zmm20 + vmovaps zmm28, zmm20 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm27, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm28, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm20 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm21 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm22 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm23 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm24 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm25 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm26 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm27 + vmovups [rbp], zmm19 + vmovups [rbp + 64], zmm28 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm28 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..21ca0f84253b --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,375 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1104 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1192] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], 
zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + # Load quantization_params pointer from stack + mov r11, [rsp + 1192] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + # Min/max clamping.. 
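+ # The dequantized float results are clamped with the same broadcast min/max
+ # from params as in the f32 kernels before being stored.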
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + vmovups [rbp], zmm20 + vmovups [r8], zmm21 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + +return: + add rsp, 1104 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..a6ae6619c078 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,471 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1104 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
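+ # kc (rdx) was rounded up to a multiple of 4 above because vpdpbusd consumes
+ # four int8 values per lane per step, and 1104 bytes of stack were reserved,
+ # part of which holds the broadcast per-row input zero points at
+ # rsp+464 through rsp+1040.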
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1192] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with k_sum * input zero point. 
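+ # The packed weights begin with the per-channel sums of the int8 weights
+ # (k_sum); multiplying them by each row's broadcast input zero point,
+ # spilled at rsp+464 through rsp+1040, yields the zero-point correction that
+ # seeds the accumulators.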
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm28, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm29, zmm7, ZMMWORD PTR [rsp + 912] + vpmulld zmm30, zmm7, ZMMWORD PTR [rsp + 976] + vpmulld zmm4, zmm7, ZMMWORD PTR [rsp + 1040] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm28, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm29, zmm2, zmm7 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpdpbusd zmm30, zmm2, zmm7 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + vpdpbusd zmm4, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
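+ # Epilogue: the int32 accumulators are converted to f32, scaled by each row's
+ # input scale from quantization_params, then fused into
+ # acc = acc * weight_scale + bias using the per-channel scales and biases that
+ # follow the int8 weights in the packed buffer (loaded from r9 below).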
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + vcvtdq2ps zmm30, zmm30 + vcvtdq2ps zmm4, zmm4 + # Load quantization_params pointer from stack + mov r11, [rsp + 1192] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 4]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 12]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 20]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 28]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 36]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 44]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 52]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 60]{1to16} + vmulps zmm30, zmm30, DWORD PTR [r11 + 68]{1to16} + vmulps zmm4, zmm4, DWORD PTR [r11 + 76]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + vfmadd213ps zmm28, zmm11, zmm7 + vfmadd213ps zmm29, zmm11, zmm7 + vfmadd213ps zmm30, zmm11, zmm7 + vfmadd213ps zmm4, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm4, zmm1, zmm4 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm4, zmm0, zmm4 + + # Pop output pointers from the stack. 
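+ # Each output row spans two zmm registers (columns 0-15 and 16-31), so the
+ # full-store path below writes 128 bytes per row and advances every output
+ # pointer by 128 before looping back for the next 32-column tile.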
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm22 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm23 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm24 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm25 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm26 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm27 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm28 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm29 + vmovups [rbp], zmm20 + vmovups [rbp + 64], zmm30 + vmovups [r8], zmm21 + vmovups [r8 + 64], zmm4 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm30 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm4 + +return: + add rsp, 1104 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..4d66ebfc7d28 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,404 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1168 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1256] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + mov edi, [r11 + 80] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1104], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + vpmulld zmm22, zmm6, ZMMWORD PTR [rsp + 1104] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + vpbroadcastd zmm2, [rdi + r11] + vpdpbusd zmm22, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + # Load quantization_params pointer from stack + mov r11, [rsp + 1256] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 84]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + vfmadd213ps zmm22, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + + # Pop output pointers from the stack. 
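+ # The 11 output pointers are reloaded from their spill slots. With 16 or more
+ # columns left each row takes one full 64-byte store; otherwise the tail
+ # builds a k1 write mask of the nc low bits ((1 << nc) - 1, computed as
+ # ~(-1 << nc)) and issues masked vmovups stores instead.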
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + vmovups [rbp], zmm20 + vmovups [r8], zmm21 + vmovups [rdi], zmm22 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + add rdi, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + vmovups ZMMWORD PTR [rdi]{k1}, zmm22 + +return: + add rsp, 1168 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..a8bf41586bc7 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,509 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1168 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1256] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + mov edi, [r11 + 80] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1104], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + vpmulld zmm22, zmm6, ZMMWORD PTR [rsp + 1104] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm28, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm29, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm30, zmm7, ZMMWORD PTR [rsp + 912] + vpmulld zmm4, zmm7, ZMMWORD PTR [rsp + 976] + vpmulld zmm8, zmm7, ZMMWORD PTR [rsp + 1040] + vpmulld zmm9, zmm7, ZMMWORD PTR [rsp + 1104] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm28, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm29, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm30, zmm2, zmm7 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpdpbusd zmm4, zmm2, zmm7 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + vpdpbusd zmm8, zmm2, zmm7 + vpbroadcastd zmm2, [rdi + r11] + vpdpbusd zmm22, zmm2, zmm6 + vpdpbusd zmm9, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
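+ # quantization_params holds one {int32 zero_point, float scale} record per
+ # row; the {1to16} embedded broadcasts below multiply each row's accumulators
+ # by its scalar input scale read at offset 4 + 8 * row.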
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + vcvtdq2ps zmm30, zmm30 + vcvtdq2ps zmm4, zmm4 + vcvtdq2ps zmm8, zmm8 + vcvtdq2ps zmm9, zmm9 + # Load quantization_params pointer from stack + mov r11, [rsp + 1256] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 84]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 4]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 12]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 20]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 28]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 36]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 44]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 52]{1to16} + vmulps zmm30, zmm30, DWORD PTR [r11 + 60]{1to16} + vmulps zmm4, zmm4, DWORD PTR [r11 + 68]{1to16} + vmulps zmm8, zmm8, DWORD PTR [r11 + 76]{1to16} + vmulps zmm9, zmm9, DWORD PTR [r11 + 84]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + vfmadd213ps zmm22, zmm10, zmm6 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + vfmadd213ps zmm28, zmm11, zmm7 + vfmadd213ps zmm29, zmm11, zmm7 + vfmadd213ps zmm30, zmm11, zmm7 + vfmadd213ps zmm4, zmm11, zmm7 + vfmadd213ps zmm8, zmm11, zmm7 + vfmadd213ps zmm9, zmm11, zmm7 + # Min/max clamping.. 
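+ # zmm0/zmm1 were broadcast from the params min/max at entry: vminps caps each
+ # accumulator at the output max and vmaxps floors it at the output min.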
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm4, zmm1, zmm4 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm4, zmm0, zmm4 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm23 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm24 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm25 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm26 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm27 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm28 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm29 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm30 + vmovups [rbp], zmm20 + vmovups [rbp + 64], zmm4 + vmovups [r8], zmm21 + vmovups [r8 + 64], zmm8 + vmovups [rdi], zmm22 + vmovups [rdi + 64], zmm9 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + add rdi, 128 + + # Write output pointers to the stack. 
+ mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm30 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm4 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm8 + vmovups ZMMWORD PTR [rdi]{k1}, zmm22 + vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm9 + +return: + add rsp, 1168 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..85a7c8a52be1 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,135 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. + ldr q10, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v3.4s, v10.s[0] + mul v14.4s, v4.4s, v10.s[0] + mul v15.4s, v5.4s, v10.s[0] + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v7.16b, v2.4b[0] + sdot v14.4s, v8.16b, v2.4b[0] + sdot v15.4s, v9.16b, v2.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[1] + fmul v14.4s, v14.4s, v10.s[1] + fmul v15.4s, v15.4s, v10.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. 
+ fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v3.4s + fmul v14.4s, v14.4s, v4.4s + fmul v15.4s, v15.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v7.4s + fadd v14.4s, v14.4s, v8.4s + fadd v15.4s, v15.4s, v9.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q13, [x6], 32 + stp q14, q15, [x6], 32 + sub x3, x3, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q13, [x6], 32 + mov v12.16b, v14.16b + mov v13.16b, v15.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + mov v12.16b, v13.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..ac491a5dde15 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,101 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 528 + + # Load quantization params pointer from stack + mov r11, [rsp + 616] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + # Load quantization_params pointer from stack + mov r11, [rsp + 616] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vmaxps zmm5, zmm0, zmm5 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + add r10, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + +return: + add rsp, 528 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..0bef4bbf8834 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,116 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 528 + + # Load quantization params pointer from stack + mov r11, [rsp + 616] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm7, ZMMWORD PTR [rsp + 464] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm12, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + # Load quantization_params pointer from stack + mov r11, [rsp + 616] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 4]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm12 + add r10, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + +return: + add rsp, 528 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..93737fb82075 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,147 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. 
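+ # Note on the calling convention (SysV AMD64): the register arguments appear
+ # to be mr in rdi, nc in rsi, kc in rdx, a in rcx, a_stride in r8 and w in r9;
+ # c, cm_stride, params and quantization_params are fetched from the stack
+ # below, at offsets that already include the 48 bytes pushed just after this.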
+ push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 528 + + # Load quantization params pointer from stack + mov r11, [rsp + 616] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm14, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm15, zmm9, ZMMWORD PTR [rsp + 464] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm12, zmm2, zmm7 + vpdpbusd zmm14, zmm2, zmm8 + vpdpbusd zmm15, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + # Load quantization_params pointer from stack + mov r11, [rsp + 616] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 4]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 4]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 4]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm11, zmm7 + vfmadd213ps zmm14, zmm2, zmm8 + vfmadd213ps zmm15, zmm3, zmm9 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm12 + vmovups [r10 + 128], zmm14 + vmovups [r10 + 192], zmm15 + add r10, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm14 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm15 + +return: + add rsp, 528 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..f74e5b726255 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,107 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. + ldr q10, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v3.4s, v10.s[0] + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v7.16b, v2.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q12, q13, [x6], 32 + sub x3, x3, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + mov v12.16b, v13.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..466642b8f274 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,186 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. 
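+ # The stp/ldp pairs use negative offsets below sp without adjusting it; the
+ # full q8-q15 registers are saved even though only their low 64 bits are
+ # callee-saved, and the matching reloads happen at `return`.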
+ stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v3.4s, v10.s[0] + mul v15.4s, v3.4s, v10.s[2] + mul v16.4s, v4.4s, v10.s[0] + mul v17.4s, v4.4s, v10.s[2] + mul v18.4s, v5.4s, v10.s[0] + mul v19.4s, v5.4s, v10.s[2] + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v7.16b, v2.4b[0] + sdot v15.4s, v7.16b, v3.4b[0] + sdot v16.4s, v8.16b, v2.4b[0] + sdot v17.4s, v8.16b, v3.4b[0] + sdot v18.4s, v9.16b, v2.4b[0] + sdot v19.4s, v9.16b, v3.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v10.s[1] + fmul v15.4s, v15.4s, v10.s[3] + fmul v16.4s, v16.4s, v10.s[1] + fmul v17.4s, v17.4s, v10.s[3] + fmul v18.4s, v18.4s, v10.s[1] + fmul v19.4s, v19.4s, v10.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v3.4s + fmul v15.4s, v15.4s, v3.4s + fmul v16.4s, v16.4s, v4.4s + fmul v17.4s, v17.4s, v4.4s + fmul v18.4s, v18.4s, v5.4s + fmul v19.4s, v19.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v7.4s + fadd v15.4s, v15.4s, v7.4s + fadd v16.4s, v16.4s, v8.4s + fadd v17.4s, v17.4s, v8.4s + fadd v18.4s, v18.4s, v9.4s + fadd v19.4s, v19.4s, v9.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + + # Check whether full or partial store. 
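+ # If fewer than 16 output channels remain, the tail labels below test bits
+ # 3..0 of nc (x1) with tbz and store 8, 4, 2 and finally 1 float per row,
+ # shifting the not-yet-stored lanes down after each partial store.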
+ cmp x1, 16 + b.lo tail_8 + stp q12, q14, [x6], 32 + stp q16, q18, [x6], 32 + stp q13, q15, [x13], 32 + stp q17, q19, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q14, [x6], 32 + stp q13, q15, [x13], 32 + mov v12.16b, v16.16b + mov v14.16b, v18.16b + mov v13.16b, v17.16b + mov v15.16b, v19.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + mov v12.16b, v14.16b + mov v13.16b, v15.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..a1d36fcaa02a --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,124 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 592 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 680] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + # Load quantization_params pointer from stack + mov r11, [rsp + 680] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + + # Check whether full or partial store. 
+ cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + add r10, 64 + add r13, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + +return: + add rsp, 592 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..e925e817701d --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,148 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 592 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 680] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 528] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm14, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + # Load quantization_params pointer from stack + mov r11, [rsp + 680] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 4]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 12]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm11, zmm7 + vfmadd213ps zmm15, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. 
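+ # For a partial 32-column store, the tail below computes (1 << nc) - 1 in
+ # r11d, moves the low 16 bits into k1 for columns 0-15 and, after shr 16,
+ # the remaining bits into k2 for columns 16-31, then issues masked vmovups
+ # stores.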
+ cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm14 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm15 + add r10, 128 + add r13, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + +return: + add rsp, 592 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..c7f7e099b888 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,197 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 592 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 680] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm16, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm17, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm18, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm19, zmm9, ZMMWORD PTR [rsp + 528] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm14, zmm2, zmm7 + vpdpbusd zmm16, zmm2, zmm8 + vpdpbusd zmm18, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + vpdpbusd zmm17, zmm2, zmm8 + vpdpbusd zmm19, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + # Load quantization_params pointer from stack + mov r11, [rsp + 680] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 4]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 12]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 4]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 12]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 4]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 12]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm11, zmm7 + vfmadd213ps zmm15, zmm11, zmm7 + vfmadd213ps zmm16, zmm2, zmm8 + vfmadd213ps zmm17, zmm2, zmm8 + vfmadd213ps zmm18, zmm3, zmm9 + vfmadd213ps zmm19, zmm3, zmm9 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm14 + vmovups [r10 + 128], zmm16 + vmovups [r10 + 192], zmm18 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm15 + vmovups [r13 + 128], zmm17 + vmovups [r13 + 192], zmm19 + add r10, 256 + add r13, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm16 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm18 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm17 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm19 + +return: + add rsp, 592 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..43b2d05e26e4 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,138 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. 
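+ # ld2r replicates two consecutive 32-bit values from the params struct: v0 gets the min broadcast to all lanes, v1 gets the max.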
+ ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v3.4s, v10.s[0] + mul v15.4s, v3.4s, v10.s[2] + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v7.16b, v2.4b[0] + sdot v15.4s, v7.16b, v3.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v10.s[1] + fmul v15.4s, v15.4s, v10.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v3.4s + fmul v15.4s, v15.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v7.4s + fadd v15.4s, v15.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q12, q14, [x6], 32 + stp q13, q15, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + mov v12.16b, v14.16b + mov v13.16b, v15.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..e0df35013602 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,235 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. 
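+ # x2 (kc) is rounded up to a multiple of 4 bytes, since each sdot step in the inner loop consumes 4 K elements per row.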
+ add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldr q10, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v3.4s, v10.s[0] + mul v16.4s, v3.4s, v10.s[2] + mul v17.4s, v3.4s, v11.s[0] + mul v18.4s, v4.4s, v10.s[0] + mul v19.4s, v4.4s, v10.s[2] + mul v20.4s, v4.4s, v11.s[0] + mul v21.4s, v5.4s, v10.s[0] + mul v22.4s, v5.4s, v10.s[2] + mul v23.4s, v5.4s, v11.s[0] + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v7.16b, v2.4b[0] + sdot v16.4s, v7.16b, v3.4b[0] + sdot v17.4s, v7.16b, v4.4b[0] + sdot v18.4s, v8.16b, v2.4b[0] + sdot v19.4s, v8.16b, v3.4b[0] + sdot v20.4s, v8.16b, v4.4b[0] + sdot v21.4s, v9.16b, v2.4b[0] + sdot v22.4s, v9.16b, v3.4b[0] + sdot v23.4s, v9.16b, v4.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v10.s[1] + fmul v16.4s, v16.4s, v10.s[3] + fmul v17.4s, v17.4s, v11.s[1] + fmul v18.4s, v18.4s, v10.s[1] + fmul v19.4s, v19.4s, v10.s[3] + fmul v20.4s, v20.4s, v11.s[1] + fmul v21.4s, v21.4s, v10.s[1] + fmul v22.4s, v22.4s, v10.s[3] + fmul v23.4s, v23.4s, v11.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v3.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + fmul v18.4s, v18.4s, v4.4s + fmul v19.4s, v19.4s, v4.4s + fmul v20.4s, v20.4s, v4.4s + fmul v21.4s, v21.4s, v5.4s + fmul v22.4s, v22.4s, v5.4s + fmul v23.4s, v23.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v7.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + fadd v18.4s, v18.4s, v8.4s + fadd v19.4s, v19.4s, v8.4s + fadd v20.4s, v20.4s, v8.4s + fadd v21.4s, v21.4s, v9.4s + fadd v22.4s, v22.4s, v9.4s + fadd v23.4s, v23.4s, v9.4s + # Min/max clamping.. 
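+ # v1 holds the output max and v0 the output min, so fmin against v1 followed by fmax against v0 clamps each result into [min, max].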
+ fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q15, [x6], 32 + stp q18, q21, [x6], 32 + stp q13, q16, [x13], 32 + stp q19, q22, [x13], 32 + stp q14, q17, [x14], 32 + stp q20, q23, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q15, [x6], 32 + stp q13, q16, [x13], 32 + stp q14, q17, [x14], 32 + mov v12.16b, v18.16b + mov v15.16b, v21.16b + mov v13.16b, v19.16b + mov v16.16b, v22.16b + mov v14.16b, v20.16b + mov v17.16b, v23.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + mov v12.16b, v15.16b + mov v13.16b, v16.16b + mov v14.16b, v17.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..ff0bedd1102f --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,147 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
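+ # Assuming the standard XNNPACK GEMM ukernel argument order, the stack slots used here (after the six pushes above plus the return address) are: c at [rsp + 56], cm_stride at [rsp + 64], and params at [rsp + 80].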
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 656 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 744] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + # Load quantization_params pointer from stack + mov r11, [rsp + 744] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + vmovups [rbx], zmm14 + add r10, 64 + add r13, 64 + add rbx, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + +return: + add rsp, 656 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..1231fb19cd4e --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,180 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
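+ # rdx (the kc argument, assuming the standard ukernel signature) is rounded up to a multiple of 4 just below, matching the 4 bytes of A consumed per row by each vpdpbusd in the inner loop.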
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 656 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 744] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 592] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + # Load quantization_params pointer from stack + mov r11, [rsp + 744] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 4]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 12]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 20]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm11, zmm7 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm15 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm16 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm17 + add r10, 128 + add r13, 128 + add rbx, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + +return: + add rsp, 656 + + # Restore the callee saved registers. 
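+ # The pops below mirror the prologue push order; the scratch frame was already released by the add rsp, 656 just above.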
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..6518bd7b0c92 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,247 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 656 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 744] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm18, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm19, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm20, zmm8, ZMMWORD PTR [rsp + 592] + vpmulld zmm21, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm22, zmm9, ZMMWORD PTR [rsp + 528] + vpmulld zmm23, zmm9, ZMMWORD PTR [rsp + 592] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + vpdpbusd zmm18, zmm2, zmm8 + vpdpbusd zmm21, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpdpbusd zmm19, zmm2, zmm8 + vpdpbusd zmm22, zmm2, zmm9 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpdpbusd zmm20, zmm2, zmm8 + vpdpbusd zmm23, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
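+ # Each row's quantization_params entry appears to be 8 bytes: an int32 zero point (read at offsets 0, 8, 16 above) followed by a float scale (broadcast from offsets 4, 12, 20 below with {1to16}).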
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + # Load quantization_params pointer from stack + mov r11, [rsp + 744] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 4]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 12]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 20]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 4]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 12]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 20]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 4]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 12]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 20]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm11, zmm7 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm2, zmm8 + vfmadd213ps zmm19, zmm2, zmm8 + vfmadd213ps zmm20, zmm2, zmm8 + vfmadd213ps zmm21, zmm3, zmm9 + vfmadd213ps zmm22, zmm3, zmm9 + vfmadd213ps zmm23, zmm3, zmm9 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm15 + vmovups [r10 + 128], zmm18 + vmovups [r10 + 192], zmm21 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm16 + vmovups [r13 + 128], zmm19 + vmovups [r13 + 192], zmm22 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm17 + vmovups [rbx + 128], zmm20 + vmovups [rbx + 192], zmm23 + add r10, 256 + add r13, 256 + add rbx, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm18 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm21 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm19 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm22 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm20 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm23 + +return: + add rsp, 656 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..dc270f1d7d74 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,167 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldr q10, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v3.4s, v10.s[0] + mul v16.4s, v3.4s, v10.s[2] + mul v17.4s, v3.4s, v11.s[0] + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v7.16b, v2.4b[0] + sdot v16.4s, v7.16b, v3.4b[0] + sdot v17.4s, v7.16b, v4.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v10.s[1] + fmul v16.4s, v16.4s, v10.s[3] + fmul v17.4s, v17.4s, v11.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v3.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v7.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + + # Check whether full or partial store. 
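+ # x1 holds the remaining nc. A full 8-wide row is stored when nc >= 8; otherwise tbz walks bits 2, 1 and 0 of nc, storing 4, 2 and 1 floats and shifting the surviving lanes down after each partial store.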
+ cmp x1, 8 + b.lo tail_4 + stp q12, q15, [x6], 32 + stp q13, q16, [x13], 32 + stp q14, q17, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + mov v12.16b, v15.16b + mov v13.16b, v16.16b + mov v14.16b, v17.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..6150aafa1556 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,284 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. 
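+ # The packed weights at x5 start with the per-output-channel weight sums (the k_sum values); multiplying them by each row's input zero point (the even lanes of v10/v11) seeds the accumulators with the zero-point correction term.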
+ ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v2.4s, v11.s[2] + mul v16.4s, v3.4s, v10.s[0] + mul v17.4s, v3.4s, v10.s[2] + mul v18.4s, v3.4s, v11.s[0] + mul v19.4s, v3.4s, v11.s[2] + mul v20.4s, v4.4s, v10.s[0] + mul v21.4s, v4.4s, v10.s[2] + mul v22.4s, v4.4s, v11.s[0] + mul v23.4s, v4.4s, v11.s[2] + mul v24.4s, v5.4s, v10.s[0] + mul v25.4s, v5.4s, v10.s[2] + mul v26.4s, v5.4s, v11.s[0] + mul v27.4s, v5.4s, v11.s[2] + add x5, x5, 64 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v6.16b, v5.4b[0] + sdot v16.4s, v7.16b, v2.4b[0] + sdot v17.4s, v7.16b, v3.4b[0] + sdot v18.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v5.4b[0] + sdot v20.4s, v8.16b, v2.4b[0] + sdot v21.4s, v8.16b, v3.4b[0] + sdot v22.4s, v8.16b, v4.4b[0] + sdot v23.4s, v8.16b, v5.4b[0] + sdot v24.4s, v9.16b, v2.4b[0] + sdot v25.4s, v9.16b, v3.4b[0] + sdot v26.4s, v9.16b, v4.4b[0] + sdot v27.4s, v9.16b, v5.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v11.s[3] + fmul v16.4s, v16.4s, v10.s[1] + fmul v17.4s, v17.4s, v10.s[3] + fmul v18.4s, v18.4s, v11.s[1] + fmul v19.4s, v19.4s, v11.s[3] + fmul v20.4s, v20.4s, v10.s[1] + fmul v21.4s, v21.4s, v10.s[3] + fmul v22.4s, v22.4s, v11.s[1] + fmul v23.4s, v23.4s, v11.s[3] + fmul v24.4s, v24.4s, v10.s[1] + fmul v25.4s, v25.4s, v10.s[3] + fmul v26.4s, v26.4s, v11.s[1] + fmul v27.4s, v27.4s, v11.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v2.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + fmul v18.4s, v18.4s, v3.4s + fmul v19.4s, v19.4s, v3.4s + fmul v20.4s, v20.4s, v4.4s + fmul v21.4s, v21.4s, v4.4s + fmul v22.4s, v22.4s, v4.4s + fmul v23.4s, v23.4s, v4.4s + fmul v24.4s, v24.4s, v5.4s + fmul v25.4s, v25.4s, v5.4s + fmul v26.4s, v26.4s, v5.4s + fmul v27.4s, v27.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v6.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + fadd v18.4s, v18.4s, v7.4s + fadd v19.4s, v19.4s, v7.4s + fadd v20.4s, v20.4s, v8.4s + fadd v21.4s, v21.4s, v8.4s + fadd v22.4s, v22.4s, v8.4s + fadd v23.4s, v23.4s, v8.4s + fadd v24.4s, v24.4s, v9.4s + fadd v25.4s, v25.4s, v9.4s + fadd v26.4s, v26.4s, v9.4s + fadd v27.4s, v27.4s, v9.4s + # Min/max clamping.. 
+ fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmin v24.4s, v1.4s, v24.4s + fmin v25.4s, v1.4s, v25.4s + fmin v26.4s, v1.4s, v26.4s + fmin v27.4s, v1.4s, v27.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + fmax v24.4s, v0.4s, v24.4s + fmax v25.4s, v0.4s, v25.4s + fmax v26.4s, v0.4s, v26.4s + fmax v27.4s, v0.4s, v27.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q16, [x6], 32 + stp q20, q24, [x6], 32 + stp q13, q17, [x13], 32 + stp q21, q25, [x13], 32 + stp q14, q18, [x14], 32 + stp q22, q26, [x14], 32 + stp q15, q19, [x15], 32 + stp q23, q27, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q16, [x6], 32 + stp q13, q17, [x13], 32 + stp q14, q18, [x14], 32 + stp q15, q19, [x15], 32 + mov v12.16b, v20.16b + mov v16.16b, v24.16b + mov v13.16b, v21.16b + mov v17.16b, v25.16b + mov v14.16b, v22.16b + mov v18.16b, v26.16b + mov v15.16b, v23.16b + mov v19.16b, v27.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + str q15, [x15], 16 + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + mov v15.16b, v19.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + str d15, [x15], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + str s15, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..088e175bf915 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,170 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
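+ # The blocks below derive one a/c pointer pair per row; the cmovle clamps make rows beyond mr alias the previous row, so loads and stores stay within the caller's buffers.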
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 720 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Load quantization params pointer from stack + mov r11, [rsp + 808] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + # Load quantization_params pointer from stack + mov r11, [rsp + 808] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + vmovups [rbx], zmm14 + vmovups [rbp], zmm15 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + +return: + add rsp, 720 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..eed325b907c7 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,212 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 720 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Load quantization params pointer from stack + mov r11, [rsp + 808] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 656] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
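+ # After the int8 weights, the packed buffer continues with the per-channel weight scales and then the biases; vfmadd213ps below computes acc = acc * weight_scale + bias in one step.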
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + # Load quantization_params pointer from stack + mov r11, [rsp + 808] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 4]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 12]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 20]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 28]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm16 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm17 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm18 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm19 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + +return: + add rsp, 720 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..709c05c9e3d9 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,297 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
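+ # The 720-byte frame reserved just below holds one 64-byte broadcast copy of each row's input zero point (at rsp + 464, 528, 592 and 656), reused by the vpmulld accumulator setup on every outer_loop iteration.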
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 720 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Load quantization params pointer from stack + mov r11, [rsp + 808] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm20, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm21, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm22, zmm8, ZMMWORD PTR [rsp + 592] + vpmulld zmm23, zmm8, ZMMWORD PTR [rsp + 656] + vpmulld zmm24, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm25, zmm9, ZMMWORD PTR [rsp + 528] + vpmulld zmm26, zmm9, ZMMWORD PTR [rsp + 592] + vpmulld zmm27, zmm9, ZMMWORD PTR [rsp + 656] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpdpbusd zmm20, zmm2, zmm8 + vpdpbusd zmm24, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpdpbusd zmm21, zmm2, zmm8 + vpdpbusd zmm25, zmm2, zmm9 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpdpbusd zmm22, zmm2, zmm8 + vpdpbusd zmm26, zmm2, zmm9 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpdpbusd zmm23, zmm2, zmm8 + vpdpbusd zmm27, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + # Load quantization_params pointer from stack + mov r11, [rsp + 808] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 4]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 12]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 20]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 28]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 4]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 12]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 20]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 28]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 4]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 12]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 20]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 28]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm2, zmm8 + vfmadd213ps zmm21, zmm2, zmm8 + vfmadd213ps zmm22, zmm2, zmm8 + vfmadd213ps zmm23, zmm2, zmm8 + vfmadd213ps zmm24, zmm3, zmm9 + vfmadd213ps zmm25, zmm3, zmm9 + vfmadd213ps zmm26, zmm3, zmm9 + vfmadd213ps zmm27, zmm3, zmm9 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + + # Check whether full or partial store. 
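+ # For the 64-wide tail a 64-bit mask with the low nc bits set is built in r11 and split into four 16-bit k registers (k1-k4), one per 16-float masked store.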
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm16 + vmovups [r10 + 128], zmm20 + vmovups [r10 + 192], zmm24 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm17 + vmovups [r13 + 128], zmm21 + vmovups [r13 + 192], zmm25 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm18 + vmovups [rbx + 128], zmm22 + vmovups [rbx + 192], zmm26 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm19 + vmovups [rbp + 128], zmm23 + vmovups [rbp + 192], zmm27 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm20 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm24 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm21 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm25 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm26 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm23 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm27 + +return: + add rsp, 720 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..4150079a1668 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,196 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Initialize k counter. + mov x20, x2 + # Initialize accumulators with k_sum * input zero point. 
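+ # x24 points at the per-row quantization params; each row contributes a {zero_point, scale} pair, so the zero points land in the even lanes of v10/v11 and the matching scales in the odd lanes used by the fmul block later.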
+ ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v2.4s, v11.s[2] + mul v16.4s, v3.4s, v10.s[0] + mul v17.4s, v3.4s, v10.s[2] + mul v18.4s, v3.4s, v11.s[0] + mul v19.4s, v3.4s, v11.s[2] + add x5, x5, 32 + +inner_loop: + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v6.16b, v5.4b[0] + sdot v16.4s, v7.16b, v2.4b[0] + sdot v17.4s, v7.16b, v3.4b[0] + sdot v18.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v5.4b[0] + subs x20, x20, 4 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v11.s[3] + fmul v16.4s, v16.4s, v10.s[1] + fmul v17.4s, v17.4s, v10.s[3] + fmul v18.4s, v18.4s, v11.s[1] + fmul v19.4s, v19.4s, v11.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v2.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + fmul v18.4s, v18.4s, v3.4s + fmul v19.4s, v19.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v6.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + fadd v18.4s, v18.4s, v7.4s + fadd v19.4s, v19.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q12, q16, [x6], 32 + stp q13, q17, [x13], 32 + stp q14, q18, [x14], 32 + stp q15, q19, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + str q15, [x15], 16 + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + mov v15.16b, v19.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + str d15, [x15], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + str s15, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..87080cb6b9c0 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,193 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 784 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Load quantization params pointer from stack + mov r11, [rsp + 872] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
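The epilogue that follows turns each int32 accumulator into the final f32 output: convert to float, multiply by the per-row input scale (the float read 4 bytes into each row's quantization-params record), then a fused multiply-add with the per-column weight scale and bias that trail the int8 weights in the packed block, and finally a clamp against the min/max values held in zmm0/zmm1. A scalar sketch of the same math, with illustrative names:

    #include <math.h>
    #include <stdint.h>

    // One output element of the qd8-f32-qc8w epilogue:
    //   c = clamp((float) acc * row_scale * col_scale + col_bias, out_min, out_max)
    static inline float dequantize_output(int32_t acc, float row_scale,
                                          float col_scale, float col_bias,
                                          float out_min, float out_max) {
      float v = (float) acc * row_scale;         // vcvtdq2ps + vmulps {1to16} broadcast
      v = v * col_scale + col_bias;              // vfmadd213ps with packed scale/bias
      return fminf(fmaxf(v, out_min), out_max);  // vmaxps / vminps clamp
    }
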
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + # Load quantization_params pointer from stack + mov r11, [rsp + 872] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + vmovups [rbx], zmm14 + vmovups [rbp], zmm15 + vmovups [r8], zmm16 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [r8]{k1}, zmm16 + +return: + add rsp, 784 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..1fe64f5ee563 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,244 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
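The prologue that follows rounds kc (rdx) up to a multiple of 4 to match the c4 packing consumed four bytes at a time by vpdpbusd, then derives one A and one C pointer per row, clamping rows beyond mr to the previous row with cmovle so that a fixed five-row kernel can safely be called with mr < 5: the surplus rows simply recompute and overwrite the last valid row rather than touching out-of-bounds memory. A sketch of the clamping, with illustrative names:

    #include <stddef.h>
    #include <stdint.h>

    // Mirror of the `cmp rdi, m; cmovle` sequences: row pointers beyond `mr`
    // alias the previous row, so no out-of-bounds memory is ever accessed.
    static void clamp_row_pointers(size_t mr, const int8_t* a0, size_t a_stride,
                                   float* c0, size_t cm_stride,
                                   const int8_t* a[5], float* c[5]) {
      a[0] = a0;
      c[0] = c0;
      for (size_t m = 1; m < 5; ++m) {
        a[m] = (mr <= m) ? a[m - 1]
                         : (const int8_t*) ((const char*) a[m - 1] + a_stride);
        c[m] = (mr <= m) ? c[m - 1]
                         : (float*) ((char*) c[m - 1] + cm_stride);
      }
    }
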
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 784 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Load quantization params pointer from stack + mov r11, [rsp + 872] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 720] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + # Load quantization_params pointer from stack + mov r11, [rsp + 872] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 4]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 12]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 20]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 28]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 36]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm17 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm18 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm19 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm20 + vmovups [r8], zmm16 + vmovups [r8 + 64], zmm21 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r8]{k1}, zmm16 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm21 + +return: + add rsp, 784 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..0744f7740160 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,347 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 784 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Load quantization params pointer from stack + mov r11, [rsp + 872] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. 
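The broadcast offsets above (+0, +8, +16, ... for the zero points, and +4, +12, +20, ... for the scales used in the epilogue) suggest an 8-byte per-row quantization-params record. The initialization that follows multiplies the packed weights' per-column sums by those broadcast zero points, and the vpdpbusd inner loop then accumulates the quantized products four bytes (one c4 group) at a time. A simplified scalar view with hypothetical names, treating the weights as if stored per column and ignoring the unsigned-by-signed operand detail of vpdpbusd:

    #include <stddef.h>
    #include <stdint.h>

    // Hypothetical per-row record implied by the +0/+8/... and +4/+12/... offsets.
    struct row_quant_params {
      int32_t zero_point;  // read at offsets 0, 8, 16, ...
      float scale;         // read at offsets 4, 12, 20, ...
    };

    // Scalar sketch of one (row, column) int32 accumulator over kc.
    static int32_t accumulate(const int8_t* a_row, const int8_t* w_col,
                              int32_t w_col_sum, int32_t zero_point, size_t kc) {
      int32_t acc = w_col_sum * zero_point;  // vpmulld against the broadcast zero point
      for (size_t k = 0; k < kc; k += 4) {   // kc already rounded up to a multiple of 4
        for (size_t j = 0; j < 4; ++j) {
          acc += (int32_t) a_row[k + j] * (int32_t) w_col[k + j];  // one vpdpbusd lane
        }
      }
      return acc;
    }
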
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm22, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm23, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm24, zmm8, ZMMWORD PTR [rsp + 592] + vpmulld zmm25, zmm8, ZMMWORD PTR [rsp + 656] + vpmulld zmm26, zmm8, ZMMWORD PTR [rsp + 720] + vpmulld zmm27, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm28, zmm9, ZMMWORD PTR [rsp + 528] + vpmulld zmm29, zmm9, ZMMWORD PTR [rsp + 592] + vpmulld zmm30, zmm9, ZMMWORD PTR [rsp + 656] + vpmulld zmm4, zmm9, ZMMWORD PTR [rsp + 720] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpdpbusd zmm22, zmm2, zmm8 + vpdpbusd zmm27, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpdpbusd zmm23, zmm2, zmm8 + vpdpbusd zmm28, zmm2, zmm9 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpdpbusd zmm24, zmm2, zmm8 + vpdpbusd zmm29, zmm2, zmm9 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpdpbusd zmm25, zmm2, zmm8 + vpdpbusd zmm30, zmm2, zmm9 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpdpbusd zmm26, zmm2, zmm8 + vpdpbusd zmm4, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + vcvtdq2ps zmm30, zmm30 + vcvtdq2ps zmm4, zmm4 + # Load quantization_params pointer from stack + mov r11, [rsp + 872] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 4]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 12]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 20]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 28]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 36]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 4]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 12]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 20]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 28]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 36]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 4]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 12]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 20]{1to16} + vmulps zmm30, zmm30, DWORD PTR [r11 + 28]{1to16} + vmulps zmm4, zmm4, DWORD PTR [r11 + 36]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm2, zmm8 + vfmadd213ps zmm23, zmm2, zmm8 + vfmadd213ps zmm24, zmm2, zmm8 + vfmadd213ps zmm25, zmm2, zmm8 + vfmadd213ps zmm26, zmm2, zmm8 + vfmadd213ps zmm27, zmm3, zmm9 + vfmadd213ps zmm28, zmm3, zmm9 + vfmadd213ps zmm29, zmm3, zmm9 + vfmadd213ps zmm30, zmm3, zmm9 + vfmadd213ps zmm4, zmm3, zmm9 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm4, zmm1, zmm4 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm4, zmm0, zmm4 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm17 + vmovups [r10 + 128], zmm22 + vmovups [r10 + 192], zmm27 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm18 + vmovups [r13 + 128], zmm23 + vmovups [r13 + 192], zmm28 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm19 + vmovups [rbx + 128], zmm24 + vmovups [rbx + 192], zmm29 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm20 + vmovups [rbp + 128], zmm25 + vmovups [rbp + 192], zmm30 + vmovups [r8], zmm16 + vmovups [r8 + 64], zmm21 + vmovups [r8 + 128], zmm26 + vmovups [r8 + 192], zmm4 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + add r8, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm27 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm23 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm28 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm24 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm29 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm25 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm30 + vmovups ZMMWORD PTR [r8]{k1}, zmm16 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r8 + 128]{k3}, zmm26 + vmovups ZMMWORD PTR [r8 + 192]{k4}, zmm4 + +return: + add rsp, 784 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..161931d21684 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,259 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 848 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 936] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with k_sum * input zero point. 
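Starting with the six-row variants there are no longer enough general-purpose registers to keep every A and C row pointer resident, so the prologue above spills the clamped pointers into fixed stack slots; each outer-loop iteration reloads the A pointers before the inner loop (as just done above), reloads the C pointers before the stores, and writes the advanced C pointers back afterwards. A structural sketch of that outer loop, with illustrative names:

    #include <stddef.h>
    #include <stdint.h>

    // Six-row outer loop over the nc dimension: row pointers live in stack
    // slots (a_slot/c_slot) rather than registers and are reloaded per tile.
    static void outer_loop_sketch(size_t nc, const int8_t* a_slot[6], float* c_slot[6]) {
      while (nc != 0) {
        const int8_t* a[6];
        for (int m = 0; m < 6; ++m) a[m] = a_slot[m];  // "Read a pointers from stack"
        (void) a;  // the int32 accumulation over kc and the f32 epilogue go here
        float* c[6];
        for (int m = 0; m < 6; ++m) c[m] = c_slot[m];  // "Pop output pointers from the stack"
        if (nc < 16) {
          break;  // masked tail store of the final partial tile through c[0..5]
        }
        // full 16-column store through c[0..5], then advance and spill back
        for (int m = 0; m < 6; ++m) c_slot[m] = c[m] + 16;
        nc -= 16;
      }
    }
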
+ vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + # Load quantization_params pointer from stack + mov r11, [rsp + 936] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + +return: + add rsp, 848 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..392f5e81e14c --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,319 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 848 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 936] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 784] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + # Load quantization_params pointer from stack + mov r11, [rsp + 936] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 4]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 12]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 20]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 28]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 36]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 44]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm18 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm19 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm20 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm21 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm22 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm23 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm23 + +return: + add rsp, 848 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..522dd0d1ab46 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,288 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 912 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1000] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + # Load quantization_params pointer from stack + mov r11, [rsp + 1000] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + +return: + add rsp, 912 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..dc51096698d7 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,357 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. 
+ mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 912 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1000] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 848] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + # Load quantization_params pointer from stack + mov r11, [rsp + 1000] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 4]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 12]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 20]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 28]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 36]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 44]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 52]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm19 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm20 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm21 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm22 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm23 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm24 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm25 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm25 + +return: + add rsp, 912 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..0f27cae0e2ef --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,317 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. 
+ mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 976 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1064] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + # Load quantization_params pointer from stack + mov r11, [rsp + 1064] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + + # Write output pointers to the stack. 
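+ # Store the advanced c pointers back to their home slots so the next
+ # 16-column block of the nc loop continues where this one stopped.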
+ mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + +return: + add rsp, 976 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..80eea889ec2c --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,395 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 976 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
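+ # Eight rows of a and c pointers exceed the free GP registers, so each
+ # clamped pointer is parked in a fixed stack slot and reloaded at the top of
+ # every loop iteration.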
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1064] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with k_sum * input zero point. 
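+ # This variant covers 32 output channels per row: zmm6/zmm7 hold the k_sums
+ # for columns 0-15 and 16-31, giving each of the eight rows a pair of
+ # accumulators (zmm5/zmm12/zmm14-zmm19 for the low half, zmm20-zmm27 for the
+ # high half).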
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 912] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + # Load quantization_params pointer from stack + mov r11, [rsp + 1064] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 4]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 12]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 20]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 28]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 36]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 44]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 52]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 60]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + 
vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm20 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm21 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm22 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm23 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm24 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm25 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm26 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm27 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm27 + +return: + add rsp, 976 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..023f5837e7ef --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,346 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1040 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1128] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, 
[r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + # Load quantization_params pointer from stack + mov r11, [rsp + 1128] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Pop output pointers from the stack. 
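+ # The ninth row uses rbp as its pointer register; rbp was saved in the
+ # prologue and is restored before returning, so it is free to hold the ninth
+ # row's pointer here.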
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + vmovups [rbp], zmm20 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + +return: + add rsp, 1040 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..217f56b415e5 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,433 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1040 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1128] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with k_sum * input zero point. 
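+ # Nine rows times two 16-wide column blocks needs 18 accumulators; they are
+ # initialized below from the k_sums in zmm6 (columns 0-15) and zmm7
+ # (columns 16-31).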
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm28, zmm7, ZMMWORD PTR [rsp + 912] + vpmulld zmm29, zmm7, ZMMWORD PTR [rsp + 976] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm28, zmm2, zmm7 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpdpbusd zmm29, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
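+ # Dequantization: the int32 accumulators are converted to float, scaled by
+ # the per-row input scale from quantization_params, then combined with the
+ # per-channel scale and bias vectors that follow in the packed weight stream
+ # via vfmadd213ps.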
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + # Load quantization_params pointer from stack + mov r11, [rsp + 1128] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 4]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 12]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 20]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 28]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 36]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 44]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 52]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 60]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 68]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + vfmadd213ps zmm28, zmm11, zmm7 + vfmadd213ps zmm29, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. 
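+ # rcx holds the remaining output channels. With at least 32 left, both
+ # 16-lane halves of every row are stored unmasked; otherwise the tail path
+ # builds the mask ~(-1 << nc) and splits it across k1/k2 so the masked
+ # stores below only touch valid columns.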
+ cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm21 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm22 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm23 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm24 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm25 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm26 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm27 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm28 + vmovups [rbp], zmm20 + vmovups [rbp + 64], zmm29 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm29 + +return: + add rsp, 1040 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index df4dfeb56d6e..fe2017be2d01 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -342,6 +342,46 @@ DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x8__avx_br DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x16__avx_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x8__avx_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane) + +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast) + +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast) 
+DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast) + DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x8__fma3_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16__fma3_broadcast) @@ -2486,6 +2526,44 @@ DECLARE_QD8_F16_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc8w_gemm_minmax_u const union xnn_f32_minmax_params params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)], \ const struct xnn_qd8_quantization_params quantization_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni) 
+DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni) + +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni) + +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane) + DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55) DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_cortex_a55) diff --git a/test/f32-gemm-minmax-2.cc b/test/f32-gemm-minmax-2.cc index 7225446bcd14..962a5614f859 100644 --- a/test/f32-gemm-minmax-2.cc +++ b/test/f32-gemm-minmax-2.cc @@ -1188,6 +1188,166 @@ std::vector CreateTests2( return info.param.test_name; }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + 
tester.Test(xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + 
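+          // Skip this suite at runtime when the host CPU lacks NEON FMA.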
TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X8__ASM_AARCH64_NEONFMA_LD64, GemmTest, testing::ValuesIn(CreateTests1( @@ -2005,7 +2165,293 @@ std::vector CreateTests2( [](const testing::TestParamInfo& info) { return info.param.test_name; }); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_7X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_8X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_11X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/11, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_7X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_8X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_11X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/11, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + 
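+          // 11 rows by 32 columns is the largest tile the new AVX512F broadcast
+          // assembly kernels provide at this width (per gemm.h, the 64-column
+          // variants stop at 5 rows).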
tester.Test(xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X16__FMA3_BROADCAST, GemmTest, testing::ValuesIn(CreateTests1( diff --git a/test/f32-gemm-minmax.cc b/test/f32-gemm-minmax.cc index 6d9a7cc824b1..b62809225bf6 100644 --- a/test/f32-gemm-minmax.cc +++ b/test/f32-gemm-minmax.cc @@ -1231,6 +1231,46 @@ std::vector CreateTests2( return info.param.test_name; }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X8__ASM_AARCH64_NEONFMA_LD128, GemmTest, testing::ValuesIn(CreateTests1( @@ -2332,7 +2372,273 @@ std::vector CreateTests2( [](const testing::TestParamInfo& info) { return info.param.test_name; }); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, 
+ /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_6X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_9X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/9, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_10X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/10, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + 
tester.Test(xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_6X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_9X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/9, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_10X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/10, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + 
xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X8__FMA3_BROADCAST, GemmTest, testing::ValuesIn(CreateTests1( diff --git a/test/f32-gemm-minmax.yaml b/test/f32-gemm-minmax.yaml index cd15dfee194c..bd67b87c1ed2 100644 --- a/test/f32-gemm-minmax.yaml +++ b/test/f32-gemm-minmax.yaml @@ -180,6 +180,48 @@ pack: xnn_pack_f32_gemm_goi_w k-block: 8 pipelined: true + +- name: xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 + - name: xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w @@ -570,6 +612,114 @@ init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w k-block: 4 +- name: xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: 
xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 - name: xnn_f32_gemm_minmax_ukernel_4x8__fma3_broadcast init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w diff --git a/test/qd8-f32-qc8w-gemm-minmax-2.cc b/test/qd8-f32-qc8w-gemm-minmax-2.cc index e3dc9c55d21c..6b00ff50cf63 100644 --- a/test/qd8-f32-qc8w-gemm-minmax-2.cc +++ b/test/qd8-f32-qc8w-gemm-minmax-2.cc @@ -640,6 +640,232 @@ std::vector CreateTests1( #endif // 
XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_5X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/5, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_9X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/9, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_10X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/10, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_5X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const 
testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_9X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_10X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_1X8C8__NEONI8MM, GemmTest, diff --git a/test/qd8-f32-qc8w-gemm-minmax-3.cc b/test/qd8-f32-qc8w-gemm-minmax-3.cc index d5e5346abe20..67a4b1f57d7f 100644 --- a/test/qd8-f32-qc8w-gemm-minmax-3.cc +++ b/test/qd8-f32-qc8w-gemm-minmax-3.cc @@ -560,6 +560,192 @@ 
std::vector CreateTests1( #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_6X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/6, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_6X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_5X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/5, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), 
+ [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_4X8C4__ASM_AARCH32_NEONDOT_CORTEX_A55, GemmTest, diff --git a/test/qd8-f32-qc8w-gemm-minmax-4.cc b/test/qd8-f32-qc8w-gemm-minmax-4.cc index 3ad716ddc062..b7df81fb9dbc 100644 --- a/test/qd8-f32-qc8w-gemm-minmax-4.cc +++ b/test/qd8-f32-qc8w-gemm-minmax-4.cc @@ -600,6 +600,152 @@ std::vector CreateTests1( #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_7X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + 
tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_7X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_1X32C8__NEONI8MM, GemmTest, diff --git 
a/test/qd8-f32-qc8w-gemm-minmax.cc b/test/qd8-f32-qc8w-gemm-minmax.cc index c25529c9e1c9..dcb2e30980a8 100644 --- a/test/qd8-f32-qc8w-gemm-minmax.cc +++ b/test/qd8-f32-qc8w-gemm-minmax.cc @@ -600,6 +600,172 @@ std::vector CreateTests1( #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_8X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_11X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/11, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_8X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_11X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/11, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + 
[](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_1X16C8__NEONI8MM, GemmTest, diff --git a/test/qd8-f32-qc8w-gemm-minmax.yaml b/test/qd8-f32-qc8w-gemm-minmax.yaml index d67ec399968c..3967610f9ea5 100644 --- a/test/qd8-f32-qc8w-gemm-minmax.yaml +++ b/test/qd8-f32-qc8w-gemm-minmax.yaml @@ -55,6 +55,176 @@ pack: xnn_pack_qs8_gemm_goi_w k-block: 64 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 
4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True + +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True + +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + 
unsigned-inputs: True + +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 # AArch32 assembly - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55 init: xnn_init_f32_minmax_scalar_params diff --git a/tools/update-microkernels.py b/tools/update-microkernels.py index e6ecd944c78b..c41b7f6e2348 100755 --- a/tools/update-microkernels.py +++ b/tools/update-microkernels.py @@ -93,6 +93,7 @@ 'wasm32', 'wasmsimd32', 'wasmrelaxedsimd32', + 'amd64', }) _MICROKERNEL_NAME_REGEX = re.compile( @@ -273,7 +274,7 @@ def main(args): for component in basename.split('-'): if component in _ARCH_LIST: arch = component - elif component in _ISA_LIST: + if component in _ISA_LIST: isa = _ISA_MAP.get(component, component) key = isa if arch is None else f'{isa}_{arch}' c_microkernels_per_isa[key].append(filepath) diff --git a/tools/xnncommon.py b/tools/xnncommon.py index 38551a0545c8..3583ec40bdf0 100644 --- a/tools/xnncommon.py +++ b/tools/xnncommon.py @@ -29,6 +29,7 @@ def _remove_duplicate_newlines(text): "aarch64": "XNN_ARCH_ARM64", "x86-32": "XNN_ARCH_X86", "x86-64": "XNN_ARCH_X86_64", + "amd64": "XNN_ARCH_X86_64", "hexagon": "XNN_ARCH_HEXAGON", "riscv": "XNN_ARCH_RISCV", "wasm": "XNN_ARCH_WASM", @@ -298,4 +299,4 @@ def make_multiline_macro(x): lines = x.strip().split('\n') max_len = max([len(i) for i in lines]) lines = [i.ljust(max_len) + "\\" for i in lines] - return "\n".join(lines)[:-1].strip() + "\n" \ No newline at end of file + return "\n".join(lines)[:-1].strip() + "\n"
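
Note for reviewers: the new `*__asm_amd64_avx512f_broadcast` and `*__asm_aarch64_neondot_lane` micro-kernels above are exercised only through the generated `INSTANTIATE_TEST_SUITE_P` suites. For a quick local smoke check of a single shape, one of them can also be driven directly through the tester. The sketch below is not part of this change; it assumes the chainable `GemmMicrokernelTester` setters (`mr`, `nr`, `kr`, `sr`, `m`, `n`, `k`) used elsewhere in this test harness, and reuses the kernel, init, and pack symbols exactly as they appear in the suites above.

// Hypothetical one-off check inside test/f32-gemm-minmax.cc, guarded the same way
// as the generated suites; a sketch only, assuming the standard chainable
// GemmMicrokernelTester API rather than the parameterized CreateTests1 path.
#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY
TEST(F32_GEMM_MINMAX_1X16__ASM_AMD64_AVX512F_BROADCAST, smoke_k8) {
  TEST_REQUIRES_X86_AVX512F;
  GemmMicrokernelTester()
      .mr(1).nr(16).kr(1).sr(1)   // tile shape of the 1x16 broadcast kernel
      .m(1).n(16).k(8)            // one output row, one full NR tile, k = 8
      .Test(xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast,
            xnn_init_f32_minmax_scalar_params,
            xnn_pack_f32_gemm_goi_w);
}
#endif  // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY

The same pattern applies to the qd8-f32-qc8w AVX512VNNI and NEONDOT kernels, substituting the corresponding k-block of 4 and xnn_pack_qs8_gemm_goi_w packing shown in their suites.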