From 9d37b8692d4a587aa699f80805a9579777e18c8c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 16:03:18 +0200 Subject: [PATCH 01/19] Add GCC to compiler check --- src/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9cead70f..bac84596 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,7 +4,7 @@ set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp) include_directories(3rdparty/llama.cpp/ggml/include) -if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR -(NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) - message(FATAL_ERROR "Clang is required for Bitnet.cpp compilation") -endif() \ No newline at end of file +if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR + NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) + message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation") +endif() From 141ddfd4fe067e0f5e1d74e4eeabca2e2ab260fe Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 16:03:18 +0200 Subject: [PATCH 02/19] Fix compiler errors on GCC --- setup_env.py | 1 - utils/codegen_tl2.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/setup_env.py b/setup_env.py index b9bf5fc5..8a9c4b46 100644 --- a/setup_env.py +++ b/setup_env.py @@ -34,7 +34,6 @@ OS_EXTRA_ARGS = { "Windows":["-T", "ClangCL"], - "Linux": ["-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"] } ARCH_ALIAS = { diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 44d24187..4d940812 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -105,7 +105,7 @@ def gen_ctor_code(): template\n\ inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\ #if defined __AVX2__\n\ - __m256 vec_lut[16];\n\ + __m256i vec_lut[16];\n\ const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\ float scales = *lut_scales;\n\ __m256i shuffle_mask = _mm256_set_epi8(\n\ @@ -191,7 +191,7 @@ def gen_ctor_code(): template\n\ inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\ #if defined __AVX2__\n\ - __m256 vec_lut[16];\n\ + __m256i vec_lut[16];\n\ const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\ float scales = *lut_scales;\n\ __m256i shuffle_mask = _mm256_set_epi8(\n\ From 5ec277e81baf5e00bb550d6b68f2f9e5254a7610 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 22:16:14 +0200 Subject: [PATCH 03/19] Put tl2 ctor in its own file --- src/ggml-bitnet-mad.cpp | 6 +- utils/codegen_tl2.py | 274 +--------------------------------------- 2 files changed, 5 insertions(+), 275 deletions(-) diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp index eeca82b1..8deff086 100644 --- a/src/ggml-bitnet-mad.cpp +++ b/src/ggml-bitnet-mad.cpp @@ -6,8 +6,8 @@ #include #include -#define QK_I2_S 128 -#define QK_I2 128 +static constexpr auto QK_I2_S = 128; +static constexpr auto QK_I2 = 128; #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #include @@ -360,4 +360,4 @@ void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t b *s = (float)sumi; #endif -} \ No newline at end of file +} diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 4d940812..678081d8 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -1,280 +1,10 @@ import argparse import os +from pathlib import Path from 
configparser import ConfigParser def gen_ctor_code(): - kernel_code = "\n\ -#include \"ggml-bitnet.h\"\n\ -#include \n\ -#include \n\ -#define GGML_BITNET_MAX_NODES 8192\n\ -static bool initialized = false;\n\ -static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\ -static size_t bitnet_tensor_extras_index = 0;\n\ -static void * aligned_malloc(size_t size) {\n\ -#if defined(_WIN32)\n\ - return _aligned_malloc(size, 64);\n\ -#else\n\ - void * ptr = nullptr;\n\ - posix_memalign(&ptr, 64, size);\n\ - return ptr;\n\ -#endif\n\ -}\n\ -\n\ -static void aligned_free(void * ptr) {\n\ -#if defined(_WIN32)\n\ - _aligned_free(ptr);\n\ -#else\n\ - free(ptr);\n\ -#endif\n\ -}\n\ -#define BK2 32\n\ -#if defined __AVX2__\n\ -inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\ -{\n\ - __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\ - __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\ - *vl = _mm256_unpacklo_epi32(va, vb);\n\ - *vh = _mm256_unpackhi_epi32(va, vb);\n\ -}\n\ -inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\ -{\n\ - __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\ - __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\ - *vl = _mm256_unpacklo_epi64(va, vb);\n\ - *vh = _mm256_unpackhi_epi64(va, vb);\n\ -}\n\ -inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\ -{\n\ - *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0));\n\ - *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1));\n\ -}\n\ -inline void Transpose_8_8(\n\ - __m256i *v0,\n\ - __m256i *v1,\n\ - __m256i *v2,\n\ - __m256i *v3,\n\ - __m256i *v4,\n\ - __m256i *v5,\n\ - __m256i *v6,\n\ - __m256i *v7)\n\ -{\n\ - __m256i w0, w1, w2, w3, w4, w5, w6, w7;\n\ - __m256i x0, x1, x2, x3, x4, x5, x6, x7;\n\ - _mm256_merge_epi32(*v0, *v1, &w0, &w1);\n\ - _mm256_merge_epi32(*v2, *v3, &w2, &w3);\n\ - _mm256_merge_epi32(*v4, *v5, &w4, &w5);\n\ - _mm256_merge_epi32(*v6, *v7, &w6, &w7);\n\ - _mm256_merge_epi64(w0, w2, &x0, &x1);\n\ - _mm256_merge_epi64(w1, w3, &x2, &x3);\n\ - _mm256_merge_epi64(w4, w6, &x4, &x5);\n\ - _mm256_merge_epi64(w5, w7, &x6, &x7);\n\ - _mm256_merge_si128(x0, x4, v0, v1);\n\ - _mm256_merge_si128(x1, x5, v2, v3);\n\ - _mm256_merge_si128(x2, x6, v4, v5);\n\ - _mm256_merge_si128(x3, x7, v6, v7);\n\ -}\n\ -#endif\n\ -inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) {\n\ - bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ - bitnet_float_type* b = (bitnet_float_type*)b_;\n\ -#if defined __AVX2__\n\ - __m256 max_vec = _mm256_set1_ps(0.f);\n\ - const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\ - for (int i = 0; i < k / 8; i++) {\n\ - __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\ - __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\ - max_vec = _mm256_max_ps(vec_babs, max_vec);\n\ - }\n\ - __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\ - max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\ - max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\ - float scales = 127 / _mm_cvtss_f32(max1);\n\ - *lut_scales = scales;\n\ -#endif\n\ - return 0;\n\ -}\n\ -inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) {\n\ - bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ - #pragma unroll\n\ - for (int i=0; i< bs; i++) {\n\ - lut_scales[i] = 0.0;\n\ - }\n\ - return 0;\n\ -}\n\ -template\n\ 
-inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\ -#if defined __AVX2__\n\ - __m256i vec_lut[16];\n\ - const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\ - float scales = *lut_scales;\n\ - __m256i shuffle_mask = _mm256_set_epi8(\n\ - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,\n\ - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00\n\ - );\n\ -#pragma unroll\n\ - for (int k = 0; k < act_k / 24; ++k) {\n\ - __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1);\n\ - __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1);\n\ - __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1);\n\ -\n\ - __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ - __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ - __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ -\n\ - vec_lut[15] = _mm256_setzero_si256();\n\ - vec_lut[14] = _mm256_setzero_si256();\n\ - vec_lut[13] = vec_b0i;\n\ - vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i);\n\ - vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i);\n\ - vec_lut[12] = vec_b0i;\n\ - vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i);\n\ - vec_lut[11] = vec_b0i;\n\ - vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i);\n\ - vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i);\n\ - vec_lut[10] = vec_b0i;\n\ - vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i);\n\ - vec_lut[9] = vec_b0i;\n\ - vec_lut[8] = vec_b0i;\n\ - vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i);\n\ - vec_lut[7] = vec_b0i;\n\ - vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i);\n\ - vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i);\n\ - vec_lut[6] = vec_b0i;\n\ - vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i);\n\ - vec_lut[5] = vec_b0i;\n\ - vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i);\n\ - vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i);\n\ - vec_lut[4] = vec_b1i;\n\ - vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i);\n\ - vec_lut[3] = vec_b1i;\n\ - vec_lut[2] = vec_b1i;\n\ - vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i);\n\ - vec_lut[1] = vec_b2i;\n\ - vec_lut[0] = _mm256_setzero_si256();\n\ - __m256i ix[16];\n\ -\n\ -#pragma unroll\n\ - for (int g = 0; g < 16; ++g) {\n\ - ix[g] = vec_lut[g];\n\ - }\n\ -\n\ - Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\ - Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\ -\n\ -#pragma unroll\n\ - for (int g = 0; g < 8; ++g) {\n\ - ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\ - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ - ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\ - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ - }\n\ - int8_t* qlut_i8 = reinterpret_cast(qlut);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);\n\ - 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);\n\ -\n\ - }\n\ -\n\ - *lut_scales = scales;\n\ -#endif\n\ - return 0;\n\ -}\n\ -\n\ -template\n\ -inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\ -#if defined __AVX2__\n\ - __m256i vec_lut[16];\n\ - const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\ - float scales = *lut_scales;\n\ - __m256i shuffle_mask = _mm256_set_epi8(\n\ - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,\n\ - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00\n\ - );\n\ -#pragma unroll\n\ - for (int k = 0; k < act_k / 16; ++k) {\n\ - __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1);\n\ - __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1);\n\ -\n\ - __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ - __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ - vec_lut[15] = _mm256_setzero_si256();\n\ - vec_lut[14] = _mm256_setzero_si256();\n\ - vec_lut[13] = _mm256_setzero_si256();\n\ - vec_lut[12] = _mm256_setzero_si256();\n\ - vec_lut[11] = _mm256_setzero_si256();\n\ - vec_lut[10] = _mm256_setzero_si256();\n\ - vec_lut[9] = _mm256_setzero_si256();\n\ - vec_lut[8] = vec_b0;\n\ - vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1);\n\ - vec_lut[7] = vec_b0;\n\ - vec_lut[6] = vec_b0;\n\ - vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1);\n\ - vec_lut[5] = vec_b1;\n\ - vec_lut[4] = _mm256_setzero_si256();\n\ - vec_lut[3] = _mm256_setzero_si256();\n\ - vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1);\n\ - vec_lut[2] = _mm256_setzero_si256();\n\ - vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0);\n\ - vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1);\n\ - vec_lut[1] = _mm256_setzero_si256();\n\ - vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0);\n\ - vec_lut[0] = _mm256_setzero_si256();\n\ - vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0);\n\ - vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1);\n\ -\n\ - __m256i ix[16];\n\ -#pragma unroll\n\ - for (int g = 0; g < 16; ++g) {\n\ - ix[g] = vec_lut[g];\n\ - }\n\ -\n\ - Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\ - Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\ -\n\ -#pragma unroll\n\ - for (int g = 0; g < 8; ++g) {\n\ - ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\ - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ - ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\ - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ - }\n\ -\n\ - int8_t* qlut_i8 = reinterpret_cast(qlut);\n\ -\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);\n\ - 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);\n\ -\n\ - }\n\ - *lut_scales = scales;\n\ -#endif\n\ - return 0;\n\ -}\n\ -static bool is_type_supported(enum ggml_type type) {\n\ - if (type == GGML_TYPE_Q4_0 ||\n\ - type == GGML_TYPE_TL2) {\n\ - return true;\n\ - } else {\n\ - return false;\n\ - }\n\ -}\n\ -" - return kernel_code + return "\n" + (Path(__file__).parent / "tl2_ctor.h").read_text(encoding='utf-8') def gen_tbl_impl(pre, BM, BK, bm, k_list): From 28b7cd11c1f20822a73902798faeffbd76df815a Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 22:29:52 +0200 Subject: [PATCH 04/19] Move to `templates` --- utils/codegen_tl2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 678081d8..b03d8ee9 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -4,7 +4,7 @@ from configparser import ConfigParser def gen_ctor_code(): - return "\n" + (Path(__file__).parent / "tl2_ctor.h").read_text(encoding='utf-8') + return "\n" + (Path(__file__).parent / "templates" / "tl2_ctor.h").read_text(encoding='utf-8') def gen_tbl_impl(pre, BM, BK, bm, k_list): From 7fb4b8e6c7d5d4c13796f60adcdd4ca2b814ad04 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 22:51:24 +0200 Subject: [PATCH 05/19] Add tl2 table impl template --- utils/codegen_tl2.py | 257 +------------------------------ utils/templates/tl2_table_impl.h | 245 +++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 251 deletions(-) create mode 100644 utils/templates/tl2_table_impl.h diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index b03d8ee9..c0d281db 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -2,262 +2,17 @@ import os from pathlib import Path from configparser import ConfigParser +from jinja2 import Environment, FileSystemLoader def gen_ctor_code(): return "\n" + (Path(__file__).parent / "templates" / "tl2_ctor.h").read_text(encoding='utf-8') def gen_tbl_impl(pre, BM, BK, bm, k_list): - - kernel_code = "\ -#include \n\ -\n\ -#define BM{0} {1}\n\ -#define BBK{0} {2}\n\ -template\n\ -inline void three_tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {{\n\ -".format(pre, BM, BK) - - kernel_code = "".join([kernel_code, "\ -#ifdef __AVX2__\n\ - const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\ - const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000);\n\ - const __m256i vec_zero = _mm256_set1_epi8(0x00);\n\ - const __m256i vec_one = _mm256_set1_epi8(0xff);\n\ - const int KK = BBK{0} / 3;\n\ -#pragma unroll\n\ - for (int i = 0; i < BM{0}; i += 32) {{\n\ - __m256i vec_as[KK / 2];\n\ - __m256i vec_signs[KK / 8];\n\ - #pragma unroll\n\ - for (int ai = 0; ai < KK / 2; ai++) {{\n\ - vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));\n\ - }}\n\ - #pragma unroll\n\ - for (int as = 0; as < KK / 8; as++) {{\n\ - vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32));\n\ - }}\n\ -#pragma unroll\n\ - for (int bs = 0; bs < 
batch_size; bs++) {{\n\ - __m256i vec_c0 = _mm256_setzero_si256();\n\ - __m256i vec_c1 = _mm256_setzero_si256();\n\ -#pragma unroll\n\ - for (int k = 0; k < KK / 8; k++) {{\n\ - __m256i vec_sign = vec_signs[k];\n\ - __m256i vec_a_0 = vec_as[k * 4 + 0];\n\ - __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs));\n\ - __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15);\n\ - __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15);\n\ - __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask);\n\ - __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0);\n\ - __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0);\n\ - __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15);\n\ - __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15);\n\ - __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask);\n\ - __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0);\n\ - __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0);\n\ - __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0);\n\ - __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0);\n\ - __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0);\n\ - __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0);\n\ - __m256i vec_a_1 = vec_as[k * 4 + 1];\n\ - __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs));\n\ - __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15);\n\ - __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15);\n\ - __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask);\n\ - __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1);\n\ - __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, 
vec_k2_1), vec_v_top_1);\n\ - __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15);\n\ - __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15);\n\ - __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask);\n\ - __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1);\n\ - __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1);\n\ - __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1);\n\ - __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1);\n\ - __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1);\n\ - __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1);\n\ - __m256i vec_a_2 = vec_as[k * 4 + 2];\n\ - __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs));\n\ - __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15);\n\ - __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15);\n\ - __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask);\n\ - __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2);\n\ - __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2);\n\ - __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15);\n\ - __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15);\n\ - __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask);\n\ - __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2);\n\ - __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2);\n\ - __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2);\n\ - __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2);\n\ - __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2);\n\ - __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, 
vec_v_top_hi_2);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2);\n\ - __m256i vec_a_3 = vec_as[k * 4 + 3];\n\ - __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));\n\ - __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));\n\ - __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);\n\ - __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);\n\ - __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);\n\ - __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);\n\ - __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);\n\ - __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);\n\ - __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);\n\ - __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);\n\ - __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);\n\ - __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);\n\ - __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);\n\ - __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);\n\ - __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);\n\ - __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);\n\ - }}\n\ - __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs));\n\ - __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs));\n\ - __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs));\n\ - __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs));\n\ - vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\ - vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\ - vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\ - vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs), vec_gc0);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs), vec_gc1);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 
+ BM{0} * bs), vec_gc2);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs), vec_gc3);\n\ - }}\n\ - }}\n\ -#endif\n\ -}}\n\ -\n\ -template\n\ -inline int32_t two_tbl_impl{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ -#ifdef __AVX2__\n\ - const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\ - const int KK = BK2 / 2;\n\ -#pragma unroll\n\ - for (int i = 0; i < BM{0}; i += 32) {{\n\ - __m256i vec_as[KK / 2];\n\ - #pragma unroll\n\ - for (int ai = 0; ai < KK / 2; ai++) {{\n\ - vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));\n\ - }}\n\ -#pragma unroll\n\ - for (int bs = 0; bs < batch_size; bs++) {{\n\ - __m256i vec_c0 = _mm256_setzero_si256();\n\ - __m256i vec_c1 = _mm256_setzero_si256();\n\ -#pragma unroll\n\ - for (int k = 0; k < KK / 8; k++) {{\n\ - #pragma unroll\n\ - for (int j = 0; j < 4; j++) {{\n\ - __m256i vec_a = vec_as[k * 4 + j];\n\ -\n\ - __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));\n\ - __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));\n\ - __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));\n\ - __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));\n\ -\n\ - __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);\n\ - __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);\n\ - __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);\n\ -\n\ - __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);\n\ - __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);\n\ - __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);\n\ -\n\ - __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);\n\ - __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);\n\ - __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);\n\ - __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);\n\ - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);\n\ - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); \n\ - }}\n\ - }}\n\ -\n\ - __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs));\n\ - __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs));\n\ - __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs));\n\ - __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs));\n\ -\n\ - vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\ - vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\ - vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\ - vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\ -\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs), vec_gc0);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs), vec_gc1);\n\ - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs), vec_gc2);\n\ - 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs), vec_gc3);\n\ - }}\n\ - }}\n\ -#endif\n\ - return 0;\n\ -}}\n\ -\n\ -template\n\ -int32_t three_qgemm_lut_{0}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ - alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\ - memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\ -#pragma unroll\n\ - for (int32_t k_outer = 0; k_outer < {1} / BBK{0}; ++k_outer) {{\n\ - three_tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 3 / 2 * BM{0})])), (&(((uint8_t*)sign)[(k_outer * BBK{0} / 3 / 8 * BM{0})])));\n\ - }}\n\ -#pragma unroll\n\ - for (int bs = 0; bs < BATCH_SIZE; bs++) {{\n\ -#pragma unroll\n\ - for (int i = 0; i < BM{0}; i++) {{\n\ - ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{0}]);\n\ - }}\n\ - }}\n\ - return 0;\n\ -}}\n\ -\n\ -template\n\ -int32_t two_qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ - alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\ - memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\ -#pragma unroll\n\ - for (int32_t k_outer = 0; k_outer < {2} / 32; ++k_outer) {{\n\ - two_tbl_impl{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{0})])));\n\ - }}\n\ -#pragma unroll\n\ - for (int bs = 0; bs < BATCH_SIZE; bs++) {{\n\ -#pragma unroll\n\ - for (int i = 0; i < BM{0}; i++) {{\n\ - ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{0}]);\n\ - ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];\n\ - }}\n\ - }}\n\ - return 0;\n\ -}}\n\ -\n\ -".format(pre, k_list[1], k_list[0])]) - return kernel_code + env = Environment( + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) + template = env.get_template("tl2_table_impl.h") + return "\n" + template.render(pre=pre, BM=BM, BK=BK, bm=bm, k_list=k_list) def gen_top_api(kernel_shapes, k_list): diff --git a/utils/templates/tl2_table_impl.h b/utils/templates/tl2_table_impl.h new file mode 100644 index 00000000..624b4bc9 --- /dev/null +++ b/utils/templates/tl2_table_impl.h @@ -0,0 +1,245 @@ +#include + +#define BM{{ pre }} {{ BM }} +#define BBK{{ pre }} {{ BK }} +template +inline void three_tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { + + +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK{{ pre }} / 3; +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = 
_mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i 
vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = 
_mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); 
+ __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_{{ pre }}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < {{ k_list[1] }} / BBK{{ pre }}; ++k_outer) { + three_tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 3 / 2 * BM{{ pre }})])), (&(((uint8_t*)sign)[(k_outer * BBK{{ pre }} / 3 / 8 * BM{{ pre }})]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM{{ pre 
}}; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < {{ k_list[0] }} / 32; ++k_outer) { + two_tbl_impl{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{{ pre }})]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} From ef16ce526301dbc1079b4bd387a70addd40c5c93 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 22:52:33 +0200 Subject: [PATCH 06/19] Add tl2 ctor template file --- utils/templates/tl2_ctor.h | 269 +++++++++++++++++++++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 utils/templates/tl2_ctor.h diff --git a/utils/templates/tl2_ctor.h b/utils/templates/tl2_ctor.h new file mode 100644 index 00000000..c65b1abb --- /dev/null +++ b/utils/templates/tl2_ctor.h @@ -0,0 +1,269 @@ +#include "ggml-bitnet.h" +#include +#include +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} +#define BK2 32 +#if defined __AVX2__ +inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi32(va, vb); + *vh = _mm256_unpackhi_epi32(va, vb); +} +inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi64(va, vb); + *vh = _mm256_unpackhi_epi64(va, vb); +} +inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0)); + *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1)); +} +inline void Transpose_8_8( + __m256i *v0, + __m256i *v1, + __m256i *v2, + __m256i *v3, + __m256i *v4, + __m256i *v5, + __m256i *v6, + __m256i *v7) +{ + __m256i w0, w1, w2, w3, w4, w5, w6, w7; + __m256i x0, x1, x2, x3, x4, x5, x6, x7; + _mm256_merge_epi32(*v0, *v1, &w0, &w1); + _mm256_merge_epi32(*v2, *v3, &w2, &w3); + _mm256_merge_epi32(*v4, *v5, &w4, &w5); + _mm256_merge_epi32(*v6, *v7, &w6, &w7); + _mm256_merge_epi64(w0, w2, &x0, &x1); + _mm256_merge_epi64(w1, w3, &x2, &x3); + _mm256_merge_epi64(w4, w6, &x4, &x5); + _mm256_merge_epi64(w5, w7, &x6, &x7); + _mm256_merge_si128(x0, x4, v0, v1); + _mm256_merge_si128(x1, x5, v2, v3); + _mm256_merge_si128(x2, x6, v4, v5); + 
_mm256_merge_si128(x3, x7, v6, v7); +} +#endif +inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#if defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + for (int i = 0; i < k / 8; i++) { + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + } + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif + return 0; +} +inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + #pragma unroll + for (int i=0; i< bs; i++) { + lut_scales[i] = 0.0; + } + return 0; +} +template +inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256i vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 24; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1); + + __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = vec_b0i; + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i); + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i); + vec_lut[12] = vec_b0i; + vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i); + vec_lut[11] = vec_b0i; + vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i); + vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i); + vec_lut[10] = vec_b0i; + vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i); + vec_lut[9] = vec_b0i; + vec_lut[8] = vec_b0i; + vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i); + vec_lut[7] = vec_b0i; + vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i); + vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i); + vec_lut[6] = vec_b0i; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i); + vec_lut[5] = vec_b0i; + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i); + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i); + vec_lut[4] = vec_b1i; + vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i); + vec_lut[3] = vec_b1i; + vec_lut[2] = vec_b1i; + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i); + vec_lut[1] = vec_b2i; + vec_lut[0] = _mm256_setzero_si256(); + __m256i ix[16]; + +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = 
vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + int8_t* qlut_i8 = reinterpret_cast(qlut); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + + *lut_scales = scales; +#endif + return 0; +} + +template +inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256i vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) { + __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1); + __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1); + + __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = _mm256_setzero_si256(); + vec_lut[12] = _mm256_setzero_si256(); + vec_lut[11] = _mm256_setzero_si256(); + vec_lut[10] = _mm256_setzero_si256(); + vec_lut[9] = _mm256_setzero_si256(); + vec_lut[8] = vec_b0; + vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1); + vec_lut[7] = vec_b0; + vec_lut[6] = vec_b0; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1); + vec_lut[5] = vec_b1; + vec_lut[4] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1); + vec_lut[2] = _mm256_setzero_si256(); + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0); + vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1); + vec_lut[1] = _mm256_setzero_si256(); + vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0); + vec_lut[0] = _mm256_setzero_si256(); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1); + + __m256i ix[16]; +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), 
&(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + + int8_t* qlut_i8 = reinterpret_cast(qlut); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + *lut_scales = scales; +#endif + return 0; +} +static bool is_type_supported(enum ggml_type type) { + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL2) { + return true; + } else { + return false; + } +} From 6b0670931914ce4d690fa9d591ac7f2129bc69b3 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 22:53:00 +0200 Subject: [PATCH 07/19] Add tl2 top api template --- utils/codegen_tl2.py | 95 ++--------------------------------- utils/templates/tl2_top_api.h | 49 ++++++++++++++++++ 2 files changed, 54 insertions(+), 90 deletions(-) create mode 100644 utils/templates/tl2_top_api.h diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index c0d281db..3341348f 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -15,97 +15,12 @@ def gen_tbl_impl(pre, BM, BK, bm, k_list): return "\n" + template.render(pre=pre, BM=BM, BK=BK, bm=bm, k_list=k_list) def gen_top_api(kernel_shapes, k_list): + env = Environment( + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) - kernel_code = "void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) {{\n\ - partial_max_reset(bs, (&(((float*)LUT_Scales)[0])));\n\ - if (m == {0} && two_k == {1} && three_k == {2}) {{\n\ - for (int32_t b = 0; b < bs; b++) {{\n\ - per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));\n\ - three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));\n\ - two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {2}])), (&(((float*)LUT_Scales)[b])));\n\ - }}\n\ - }}\n\ -".format(kernel_shapes[0][0], k_list[0][0], k_list[0][1]) - for i in range(1, len(kernel_shapes)): - kernel_code = "".join([kernel_code, " else if (m == {0} && two_k == {1} && three_k == {2}) {{\n\ - for (int32_t b = 0; b < bs; b++) {{\n\ - per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));\n\ - three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));\n\ - two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {2}])), (&(((float*)LUT_Scales)[b])));\n\ - }}\n\ - }}\n".format(kernel_shapes[i][0], 
k_list[i][0], k_list[i][1])]) - kernel_code = "".join([kernel_code, "}\n"]) - - - kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ - if (m == {0} && k == {1}) {{\n\ - if (BK == {2}) {{\n\ - if (bs == 1) {{\n\ - two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 8) {{\n\ - two_qgemm_lut_{4}<8>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 32) {{\n\ - two_qgemm_lut_{4}<32>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 128) {{\n\ - two_qgemm_lut_{4}<128>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 256) {{\n\ - two_qgemm_lut_{4}<256>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 512) {{\n\ - two_qgemm_lut_{4}<512>(A, LUT, Scales, LUT_Scales, C);\n\ - }}\n\ - }}\n\ - else if (BK == {3}) {{\n\ - if (bs == 1) {{\n\ - three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 8) {{\n\ - three_qgemm_lut_{4}<8>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 32) {{\n\ - three_qgemm_lut_{4}<32>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 128) {{\n\ - three_qgemm_lut_{4}<128>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 256) {{\n\ - three_qgemm_lut_{4}<256>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 512) {{\n\ - three_qgemm_lut_{4}<512>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}\n\ - }}\n\ - }}\n\ -".format(kernel_shapes[0][0], kernel_shapes[0][1], k_list[0][0], k_list[0][1], "{}_{}".format(kernel_shapes[0][0], kernel_shapes[0][1]))]) - for i in range(1, len(kernel_shapes)): - kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ - if (BK == {2}) {{\n\ - if (bs == 1) {{\n\ - two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 8) {{\n\ - two_qgemm_lut_{4}<8>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 32) {{\n\ - two_qgemm_lut_{4}<32>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 128) {{\n\ - two_qgemm_lut_{4}<128>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 256) {{\n\ - two_qgemm_lut_{4}<256>(A, LUT, Scales, LUT_Scales, C);\n\ - }} else if (bs == 512) {{\n\ - two_qgemm_lut_{4}<512>(A, LUT, Scales, LUT_Scales, C);\n\ - }}\n\ - }}\n\ - else if (BK == {3}) {{\n\ - if (bs == 1) {{\n\ - three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 8) {{\n\ - three_qgemm_lut_{4}<8>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 32) {{\n\ - three_qgemm_lut_{4}<32>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 128) {{\n\ - three_qgemm_lut_{4}<128>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 256) {{\n\ - three_qgemm_lut_{4}<256>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}else if (bs == 512) {{\n\ - three_qgemm_lut_{4}<512>(A, sign, LUT, Scales, LUT_Scales, C);\n\ - }}\n\ - }}\n\ - }}\n\ -".format(kernel_shapes[i][0], kernel_shapes[i][1], k_list[i][0], k_list[i][1], "{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]))]) - kernel_code = "".join([kernel_code, "}\n"]) + template = env.get_template("tl2_top_api.h") + kernel_code = "\n" + template.render(kernel_shapes=kernel_shapes, k_list=k_list) + "\n" return kernel_code def gen_transform_code(kernel_shapes): diff --git a/utils/templates/tl2_top_api.h b/utils/templates/tl2_top_api.h new file mode 100644 index 00000000..681546bd --- /dev/null +++ b/utils/templates/tl2_top_api.h @@ -0,0 +1,49 @@ +void ggml_preprocessor(int bs, int m, int three_k, int two_k, 
void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) { + partial_max_reset(bs, (&(((float*)LUT_Scales)[0]))); +{% for kernel_shape in kernel_shapes %} + {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && two_k == {{ k_list[loop.index0][0] }} && three_k == {{ k_list[loop.index0][1] }}) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<{{ k_list[loop.index0][1] }}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<{{ k_list[loop.index0][0] }}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {{ k_list[loop.index0][1] }}])), (&(((float*)LUT_Scales)[b]))); + } + } +{% endfor %} +} + +void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { +{% for kernel_shape in kernel_shapes %} + {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) { + if (BK == {{ k_list[loop.index0][0] }}) { + if (bs == 1) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == {{ k_list[loop.index0][1] }}) { + if (bs == 1) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } +{% endfor %} +} From e8b47ae9246e0af5d40839078445390f3f5279fe Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 23:19:12 +0200 Subject: [PATCH 08/19] Add tl2 gen transform template --- utils/codegen_tl2.py | 52 ++++------------------------- utils/templates/tl2_gen_transform.h | 37 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 46 deletions(-) create mode 100644 utils/templates/tl2_gen_transform.h diff --git 
a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 3341348f..06d5cc36 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -24,52 +24,12 @@ def gen_top_api(kernel_shapes, k_list): return kernel_code def gen_transform_code(kernel_shapes): - kernel_code = "\n\ -void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\ - if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\ - return;\n\ - }\n\ -\n\ - int k = tensor->ne[0];\n\ - int m = tensor->ne[1];\n\ - const int lut_scales_size = 1;\n\ - int bk = 0;\n\ - int bm = 0;\n" - - kernel_code = "".join([kernel_code, "\n\ - if (m == {0} && k == {1}) {{\n\ - bm = BM{0}_{1};\n\ - bk = BBK{0}_{1};\n\ - }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])]) - - for i in range(1, len(kernel_shapes)): - kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\ - bm = BM{0}_{1};\n\ - bk = BBK{0}_{1};\n\ - }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) - - kernel_code = "".join([kernel_code, "\n\ - const int n_tile_num = m / bm;\n\ - const int BK = bk;\n\ - uint8_t * qweights;\n\ - bitnet_float_type * scales;\n\ -\n\ - scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\ - qweights = (uint8_t *) tensor->data;\n\ - int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8;\n\ - if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes;\n\ - float * i2_scales = (float * )(qweights + nbytes);\n\ - scales[0] = (bitnet_float_type) i2_scales[0];\n\ -\n\ - tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\ - bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\ - /* .lut_scales_size = */ lut_scales_size,\n\ - /* .BK = */ BK,\n\ - /* .n_tile_num = */ n_tile_num,\n\ - /* .qweights = */ qweights,\n\ - /* .scales = */ scales\n\ - };\n\ -}\n"]) + env = Environment( + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) + + template = env.get_template("tl2_gen_transform.h") + kernel_code = "\n" + template.render(kernel_shapes=kernel_shapes) + "\n" return kernel_code diff --git a/utils/templates/tl2_gen_transform.h b/utils/templates/tl2_gen_transform.h new file mode 100644 index 00000000..d7159620 --- /dev/null +++ b/utils/templates/tl2_gen_transform.h @@ -0,0 +1,37 @@ +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + int bk = 0; + int bm = 0; + {% for kernel_shape in kernel_shapes %} + {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) { + bm = BM{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}; + bk = BBK{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}; + } + {% endfor %} + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8; + if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes; + float * i2_scales = (float * )(qweights + nbytes); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* 
.lut_scales_size = */ lut_scales_size, + /* .BK = */ BK, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} From da9d961132b68a8581148c95c89760b0c5fb478c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 23:28:10 +0200 Subject: [PATCH 09/19] Formatting --- utils/codegen_tl2.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 06d5cc36..da0ebf7c 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -5,33 +5,35 @@ from jinja2 import Environment, FileSystemLoader def gen_ctor_code(): - return "\n" + (Path(__file__).parent / "templates" / "tl2_ctor.h").read_text(encoding='utf-8') + env = Environment( + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) + template = env.get_template("tl2_ctor.h") + return "\n" + template.render() def gen_tbl_impl(pre, BM, BK, bm, k_list): env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) template = env.get_template("tl2_table_impl.h") return "\n" + template.render(pre=pre, BM=BM, BK=BK, bm=bm, k_list=k_list) def gen_top_api(kernel_shapes, k_list): env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) template = env.get_template("tl2_top_api.h") - kernel_code = "\n" + template.render(kernel_shapes=kernel_shapes, k_list=k_list) + "\n" - return kernel_code + return "\n" + template.render(kernel_shapes=kernel_shapes, k_list=k_list) + "\n" def gen_transform_code(kernel_shapes): env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) template = env.get_template("tl2_gen_transform.h") - kernel_code = "\n" + template.render(kernel_shapes=kernel_shapes) + "\n" + return "\n" + template.render(kernel_shapes=kernel_shapes) + "\n" - return kernel_code def get_three_k_two_k(K, bk): bk_num = K // bk From 005ec236085aa6858c1e5fe38ac3baec40fb7fed Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 23:56:09 +0200 Subject: [PATCH 10/19] Better impl for table --- utils/codegen_tl2.py | 14 ++++---------- utils/templates/tl2_table_impl.h | 15 ++++++++++----- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index da0ebf7c..204e08ff 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -11,12 +11,12 @@ def gen_ctor_code(): template = env.get_template("tl2_ctor.h") return "\n" + template.render() -def gen_tbl_impl(pre, BM, BK, bm, k_list): +def gen_tbl_impl(kernel_shapes, BM_list, BK_list, bm_list, k_list): env = Environment( loader=FileSystemLoader(Path(__file__).parent / "templates"), ) template = env.get_template("tl2_table_impl.h") - return "\n" + template.render(pre=pre, BM=BM, BK=BK, bm=bm, k_list=k_list) + return "\n" + template.render(kernel_shapes=kernel_shapes, BM_list=BM_list, BK_list=BK_list, bm_list=bm_list, k_list=k_list) def gen_top_api(kernel_shapes, k_list): env = Environment( @@ -72,17 +72,11 @@ def get_three_k_two_k(K, bk): BK_list = [int(item) for item in args.BK.split(',')] bm_list = [int(item) for item in args.bm.split(',')] - tbl_impl_code = [] k_list = [] for i in range(len(kernel_shapes)): k_list.append(get_three_k_two_k(kernel_shapes[i][1], BK_list[i])) - for i in 
range(len(kernel_shapes)): - tbl_impl_code.append( - gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], k_list[i]) - ) - assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm shoud be {}".format(len(kernel_shapes)) for i in range(len(kernel_shapes)): @@ -93,14 +87,14 @@ def get_three_k_two_k(K, bk): ctor_code = gen_ctor_code() api_code = gen_top_api(kernel_shapes, k_list) trans_code = gen_transform_code(kernel_shapes) + tbl_impl_code = gen_tbl_impl(kernel_shapes, BM_list, BK_list, bm_list, k_list) output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include") with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f: f.write(''.join("#if defined(GGML_BITNET_X86_TL2)")) f.write(''.join(ctor_code)) - for code in tbl_impl_code: - f.write(''.join(code)) + f.write(''.join(tbl_impl_code)) f.write(''.join(api_code)) f.write(''.join(trans_code)) f.write(''.join("#endif")) diff --git a/utils/templates/tl2_table_impl.h b/utils/templates/tl2_table_impl.h index 624b4bc9..1b044a74 100644 --- a/utils/templates/tl2_table_impl.h +++ b/utils/templates/tl2_table_impl.h @@ -1,11 +1,15 @@ -#include +{% for kernel_shape in kernel_shapes %} +{% set pre = kernel_shape[0] ~ "_" ~ kernel_shape[1] %} +{% set BM = BM_list[loop.index0] %} +{% set BK = BK_list[loop.index0] %} +{% set bm = bm_list[loop.index0] %} +{% set k_list = k_list[loop.index0] %} + +static constexpr auto BM{{ pre }} = {{ BM }}; +static constexpr auto BBK{{ pre }} = {{ BK }}; -#define BM{{ pre }} {{ BM }} -#define BBK{{ pre }} {{ BK }} template inline void three_tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { - - #ifdef __AVX2__ const __m256i vec_mask = _mm256_set1_epi8(0x0f); const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); @@ -243,3 +247,4 @@ int32_t two_qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scal } return 0; } +{% endfor %} From 4a34de23c5cc4f527c4f6dbaf520ca13e08f7d88 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 25 Oct 2024 23:59:29 +0200 Subject: [PATCH 11/19] Fix name collision --- utils/templates/tl2_table_impl.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/utils/templates/tl2_table_impl.h b/utils/templates/tl2_table_impl.h index 1b044a74..3782bc90 100644 --- a/utils/templates/tl2_table_impl.h +++ b/utils/templates/tl2_table_impl.h @@ -3,7 +3,7 @@ {% set BM = BM_list[loop.index0] %} {% set BK = BK_list[loop.index0] %} {% set bm = bm_list[loop.index0] %} -{% set k_list = k_list[loop.index0] %} +{% set k_list_indexed = k_list[loop.index0] %} static constexpr auto BM{{ pre }} = {{ BM }}; static constexpr auto BBK{{ pre }} = {{ BK }}; @@ -216,8 +216,8 @@ int32_t three_qgemm_lut_{{ pre }}(void* A, void* sign, void* LUT, void* Scales, alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); #pragma unroll - for (int32_t k_outer = 0; k_outer < {{ k_list[1] }} / BBK{{ pre }}; ++k_outer) { - three_tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 3 / 2 * BM{{ pre }})])), (&(((uint8_t*)sign)[(k_outer * BBK{{ pre }} / 3 / 8 * BM{{ pre }})]))); + for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[1] }} / BBK{{ pre }}; ++k_outer) { + three_tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 3 * 32)])), 
(&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 3 / 2 * BM{{ pre }})])), (&(((uint8_t*)sign)[(k_outer * BBK{{ pre }} / 3 / 8 * BM{{ pre }})]))); } #pragma unroll for (int bs = 0; bs < BATCH_SIZE; bs++) { @@ -234,8 +234,8 @@ int32_t two_qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scal alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); #pragma unroll - for (int32_t k_outer = 0; k_outer < {{ k_list[0] }} / 32; ++k_outer) { - two_tbl_impl{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{{ pre }})]))); + for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[0] }} / 32; ++k_outer) { + two_tbl_impl{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{{ pre }})]))); } #pragma unroll for (int bs = 0; bs < BATCH_SIZE; bs++) { From 22fcf0f2ec3a5a19918922c64aa15445550d79ea Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Sat, 26 Oct 2024 00:09:23 +0200 Subject: [PATCH 12/19] Use one template for the whole tl2 codegen --- requirements.txt | 3 +- utils/codegen_tl2.py | 63 ++--- utils/templates/tl2.h | 607 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 626 insertions(+), 47 deletions(-) create mode 100644 utils/templates/tl2.h diff --git a/requirements.txt b/requirements.txt index 3f5c5472..bdff4394 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt --r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt \ No newline at end of file +-r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +jinja2 diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 204e08ff..e56a25da 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -1,40 +1,8 @@ import argparse -import os from pathlib import Path from configparser import ConfigParser from jinja2 import Environment, FileSystemLoader -def gen_ctor_code(): - env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) - template = env.get_template("tl2_ctor.h") - return "\n" + template.render() - -def gen_tbl_impl(kernel_shapes, BM_list, BK_list, bm_list, k_list): - env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) - template = env.get_template("tl2_table_impl.h") - return "\n" + template.render(kernel_shapes=kernel_shapes, BM_list=BM_list, BK_list=BK_list, bm_list=bm_list, k_list=k_list) - -def gen_top_api(kernel_shapes, k_list): - env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) - - template = env.get_template("tl2_top_api.h") - return "\n" + template.render(kernel_shapes=kernel_shapes, k_list=k_list) + "\n" - -def gen_transform_code(kernel_shapes): - env = Environment( - loader=FileSystemLoader(Path(__file__).parent / "templates"), - ) - - template = env.get_template("tl2_gen_transform.h") - return "\n" + template.render(kernel_shapes=kernel_shapes) + "\n" - - def get_three_k_two_k(K, bk): bk_num = K // bk three_k = bk_num * bk @@ -84,20 +52,22 @@ def get_three_k_two_k(K, bk): assert (kernel_shapes[i][1] % BK_list[i]) % 32 == 0, "K %% BK %% 32 should be 0" assert bm_list[i] in [32], "choose bm from [32]" - ctor_code = 
gen_ctor_code() - api_code = gen_top_api(kernel_shapes, k_list) - trans_code = gen_transform_code(kernel_shapes) - tbl_impl_code = gen_tbl_impl(kernel_shapes, BM_list, BK_list, bm_list, k_list) + env = Environment( + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) + template = env.get_template("tl2.h") + contents = template.render( + kernel_shapes=kernel_shapes, + k_list=k_list, + BM_list=BM_list, + BK_list=BK_list, + bm_list=bm_list, + ) - output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include") + output_dir = Path(__file__).resolve().parent.parent / "include" - with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f: - f.write(''.join("#if defined(GGML_BITNET_X86_TL2)")) - f.write(''.join(ctor_code)) - f.write(''.join(tbl_impl_code)) - f.write(''.join(api_code)) - f.write(''.join(trans_code)) - f.write(''.join("#endif")) + header_file = output_dir / "bitnet-lut-kernels.h" + header_file.write_text(contents, encoding='utf-8') config = ConfigParser() @@ -109,5 +79,6 @@ def get_three_k_two_k(K, bk): config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i])) config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i])) - with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile: - config.write(configfile) \ No newline at end of file + config_file = output_dir / "kernel_config.ini" + with open(config_file, 'w', encoding='utf-8') as configfile: + config.write(configfile) diff --git a/utils/templates/tl2.h b/utils/templates/tl2.h new file mode 100644 index 00000000..e75c8457 --- /dev/null +++ b/utils/templates/tl2.h @@ -0,0 +1,607 @@ +#if defined(GGML_BITNET_X86_TL2) +#include "ggml-bitnet.h" +#include +#include +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} +#define BK2 32 +#if defined __AVX2__ +inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi32(va, vb); + *vh = _mm256_unpackhi_epi32(va, vb); +} +inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi64(va, vb); + *vh = _mm256_unpackhi_epi64(va, vb); +} +inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0)); + *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1)); +} +inline void Transpose_8_8( + __m256i *v0, + __m256i *v1, + __m256i *v2, + __m256i *v3, + __m256i *v4, + __m256i *v5, + __m256i *v6, + __m256i *v7) +{ + __m256i w0, w1, w2, w3, w4, w5, w6, w7; + __m256i x0, x1, x2, x3, x4, x5, x6, x7; + _mm256_merge_epi32(*v0, *v1, &w0, &w1); + _mm256_merge_epi32(*v2, *v3, &w2, &w3); + _mm256_merge_epi32(*v4, *v5, &w4, &w5); + _mm256_merge_epi32(*v6, *v7, &w6, &w7); + 
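+    // The merge stages below complete an 8x8 transpose of the 32-bit elements held
+    // across v0..v7: the 32-bit interleaves above, then 64-bit interleaves, then a
+    // final swap of 128-bit halves.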
_mm256_merge_epi64(w0, w2, &x0, &x1); + _mm256_merge_epi64(w1, w3, &x2, &x3); + _mm256_merge_epi64(w4, w6, &x4, &x5); + _mm256_merge_epi64(w5, w7, &x6, &x7); + _mm256_merge_si128(x0, x4, v0, v1); + _mm256_merge_si128(x1, x5, v2, v3); + _mm256_merge_si128(x2, x6, v4, v5); + _mm256_merge_si128(x3, x7, v6, v7); +} +#endif +inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#if defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + for (int i = 0; i < k / 8; i++) { + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + } + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif + return 0; +} +inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + #pragma unroll + for (int i=0; i< bs; i++) { + lut_scales[i] = 0.0; + } + return 0; +} +template +inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256i vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 24; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1); + + __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = vec_b0i; + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i); + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i); + vec_lut[12] = vec_b0i; + vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i); + vec_lut[11] = vec_b0i; + vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i); + vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i); + vec_lut[10] = vec_b0i; + vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i); + vec_lut[9] = vec_b0i; + vec_lut[8] = vec_b0i; + vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i); + vec_lut[7] = vec_b0i; + vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i); + vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i); + vec_lut[6] = vec_b0i; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i); + vec_lut[5] = vec_b0i; + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i); + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i); + vec_lut[4] = vec_b1i; + vec_lut[4] = 
_mm256_add_epi32(vec_lut[4], vec_b2i); + vec_lut[3] = vec_b1i; + vec_lut[2] = vec_b1i; + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i); + vec_lut[1] = vec_b2i; + vec_lut[0] = _mm256_setzero_si256(); + __m256i ix[16]; + +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + int8_t* qlut_i8 = reinterpret_cast(qlut); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + + *lut_scales = scales; +#endif + return 0; +} + +template +inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256i vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) { + __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1); + __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1); + + __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = _mm256_setzero_si256(); + vec_lut[12] = _mm256_setzero_si256(); + vec_lut[11] = _mm256_setzero_si256(); + vec_lut[10] = _mm256_setzero_si256(); + vec_lut[9] = _mm256_setzero_si256(); + vec_lut[8] = vec_b0; + vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1); + vec_lut[7] = vec_b0; + vec_lut[6] = vec_b0; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1); + vec_lut[5] = vec_b1; + vec_lut[4] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1); + vec_lut[2] = _mm256_setzero_si256(); + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0); + vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1); + vec_lut[1] = _mm256_setzero_si256(); + vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0); + vec_lut[0] = _mm256_setzero_si256(); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0); + 
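+        // Together with the next subtraction, vec_lut[0] becomes -b0 - b1. Entries
+        // 0..8 hold s0*b0 + s1*b1 for every s0, s1 in {-1, 0, +1} (index 4 + 3*s0 + s1);
+        // entries 9..15 stay zero so the table fills all 16 slots addressed by the
+        // 4-bit indices used in the shuffle-based lookups.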
vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1); + + __m256i ix[16]; +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + + int8_t* qlut_i8 = reinterpret_cast(qlut); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + *lut_scales = scales; +#endif + return 0; +} +static bool is_type_supported(enum ggml_type type) { + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL2) { + return true; + } else { + return false; + } +} +{% for kernel_shape in kernel_shapes %} +{% set pre = kernel_shape[0] ~ "_" ~ kernel_shape[1] %} +{% set BM = BM_list[loop.index0] %} +{% set BK = BK_list[loop.index0] %} +{% set bm = bm_list[loop.index0] %} +{% set k_list_indexed = k_list[loop.index0] %} + +static constexpr auto BM{{ pre }} = {{ BM }}; +static constexpr auto BBK{{ pre }} = {{ BK }}; + +template +inline void three_tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK{{ pre }} / 3; +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = 
_mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = 
_mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 
4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = 
_mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_{{ pre }}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[1] }} / BBK{{ pre }}; ++k_outer) { + three_tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 3 / 2 * BM{{ pre }})])), (&(((uint8_t*)sign)[(k_outer * BBK{{ pre }} / 3 / 8 * BM{{ pre }})]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; + memset(&(CBits[0]), 0, BATCH_SIZE * 
BM{{ pre }} * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[0] }} / 32; ++k_outer) { + two_tbl_impl{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{{ pre }})]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM{{ pre }}; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} +{% endfor %} +void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) { + partial_max_reset(bs, (&(((float*)LUT_Scales)[0]))); +{% for kernel_shape in kernel_shapes %} + {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && two_k == {{ k_list[loop.index0][0] }} && three_k == {{ k_list[loop.index0][1] }}) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<{{ k_list[loop.index0][1] }}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<{{ k_list[loop.index0][0] }}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {{ k_list[loop.index0][1] }}])), (&(((float*)LUT_Scales)[b]))); + } + } +{% endfor %} +} + +void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { +{% for kernel_shape in kernel_shapes %} + {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) { + if (BK == {{ k_list[loop.index0][0] }}) { + if (bs == 1) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == {{ k_list[loop.index0][1] }}) { + if (bs == 1) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ 
kernel_shapes[loop.index0][1] }}<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } +{% endfor %} +} +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + int bk = 0; + int bm = 0; + {% for kernel_shape in kernel_shapes %} + {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) { + bm = BM{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}; + bk = BBK{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}; + } + {% endfor %} + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8; + if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes; + float * i2_scales = (float * )(qweights + nbytes); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .BK = */ BK, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif From ce6ad078aecc4f431ffe92ae5a6f85398ca2fc58 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Sat, 26 Oct 2024 00:09:59 +0200 Subject: [PATCH 13/19] Remove unused templates --- utils/templates/tl2_ctor.h | 269 ---------------------------- utils/templates/tl2_gen_transform.h | 37 ---- utils/templates/tl2_table_impl.h | 250 -------------------------- utils/templates/tl2_top_api.h | 49 ----- 4 files changed, 605 deletions(-) delete mode 100644 utils/templates/tl2_ctor.h delete mode 100644 utils/templates/tl2_gen_transform.h delete mode 100644 utils/templates/tl2_table_impl.h delete mode 100644 utils/templates/tl2_top_api.h diff --git a/utils/templates/tl2_ctor.h b/utils/templates/tl2_ctor.h deleted file mode 100644 index c65b1abb..00000000 --- a/utils/templates/tl2_ctor.h +++ /dev/null @@ -1,269 +0,0 @@ -#include "ggml-bitnet.h" -#include -#include -#define GGML_BITNET_MAX_NODES 8192 -static bool initialized = false; -static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; -static size_t bitnet_tensor_extras_index = 0; -static void * aligned_malloc(size_t size) { -#if defined(_WIN32) - return _aligned_malloc(size, 64); -#else - void * ptr = nullptr; - posix_memalign(&ptr, 64, size); - return ptr; -#endif -} - -static void aligned_free(void * ptr) { -#if defined(_WIN32) - _aligned_free(ptr); -#else - free(ptr); -#endif -} -#define BK2 32 -#if defined __AVX2__ -inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) -{ - __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); - __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); - *vl = _mm256_unpacklo_epi32(va, vb); - *vh = _mm256_unpackhi_epi32(va, vb); -} -inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) -{ - __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); - 
__m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); - *vl = _mm256_unpacklo_epi64(va, vb); - *vh = _mm256_unpackhi_epi64(va, vb); -} -inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) -{ - *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0)); - *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1)); -} -inline void Transpose_8_8( - __m256i *v0, - __m256i *v1, - __m256i *v2, - __m256i *v3, - __m256i *v4, - __m256i *v5, - __m256i *v6, - __m256i *v7) -{ - __m256i w0, w1, w2, w3, w4, w5, w6, w7; - __m256i x0, x1, x2, x3, x4, x5, x6, x7; - _mm256_merge_epi32(*v0, *v1, &w0, &w1); - _mm256_merge_epi32(*v2, *v3, &w2, &w3); - _mm256_merge_epi32(*v4, *v5, &w4, &w5); - _mm256_merge_epi32(*v6, *v7, &w6, &w7); - _mm256_merge_epi64(w0, w2, &x0, &x1); - _mm256_merge_epi64(w1, w3, &x2, &x3); - _mm256_merge_epi64(w4, w6, &x4, &x5); - _mm256_merge_epi64(w5, w7, &x6, &x7); - _mm256_merge_si128(x0, x4, v0, v1); - _mm256_merge_si128(x1, x5, v2, v3); - _mm256_merge_si128(x2, x6, v4, v5); - _mm256_merge_si128(x3, x7, v6, v7); -} -#endif -inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) { - bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; - bitnet_float_type* b = (bitnet_float_type*)b_; -#if defined __AVX2__ - __m256 max_vec = _mm256_set1_ps(0.f); - const __m256 vec_sign = _mm256_set1_ps(-0.0f); - for (int i = 0; i < k / 8; i++) { - __m256 vec_b = _mm256_loadu_ps(b + i * 8); - __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); - max_vec = _mm256_max_ps(vec_babs, max_vec); - } - __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); - max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); - max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); - float scales = 127 / _mm_cvtss_f32(max1); - *lut_scales = scales; -#endif - return 0; -} -inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) { - bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; - #pragma unroll - for (int i=0; i< bs; i++) { - lut_scales[i] = 0.0; - } - return 0; -} -template -inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { -#if defined __AVX2__ - __m256i vec_lut[16]; - const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0); - float scales = *lut_scales; - __m256i shuffle_mask = _mm256_set_epi8( - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 - ); -#pragma unroll - for (int k = 0; k < act_k / 24; ++k) { - __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1); - __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1); - __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1); - - __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - vec_lut[15] = _mm256_setzero_si256(); - vec_lut[14] = _mm256_setzero_si256(); - vec_lut[13] = vec_b0i; - vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i); - vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i); - 
vec_lut[12] = vec_b0i; - vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i); - vec_lut[11] = vec_b0i; - vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i); - vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i); - vec_lut[10] = vec_b0i; - vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i); - vec_lut[9] = vec_b0i; - vec_lut[8] = vec_b0i; - vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i); - vec_lut[7] = vec_b0i; - vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i); - vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i); - vec_lut[6] = vec_b0i; - vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i); - vec_lut[5] = vec_b0i; - vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i); - vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i); - vec_lut[4] = vec_b1i; - vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i); - vec_lut[3] = vec_b1i; - vec_lut[2] = vec_b1i; - vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i); - vec_lut[1] = vec_b2i; - vec_lut[0] = _mm256_setzero_si256(); - __m256i ix[16]; - -#pragma unroll - for (int g = 0; g < 16; ++g) { - ix[g] = vec_lut[g]; - } - - Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); - Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); - -#pragma unroll - for (int g = 0; g < 8; ++g) { - ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); - ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); - } - int8_t* qlut_i8 = reinterpret_cast(qlut); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); - - } - - *lut_scales = scales; -#endif - return 0; -} - -template -inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { -#if defined __AVX2__ - __m256i vec_lut[16]; - const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0); - float scales = *lut_scales; - __m256i shuffle_mask = _mm256_set_epi8( - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, - 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, - 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 - ); -#pragma unroll - for (int k = 0; k < act_k / 16; ++k) { - __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1); - __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1); - - __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - vec_lut[15] = _mm256_setzero_si256(); - vec_lut[14] = _mm256_setzero_si256(); - vec_lut[13] = _mm256_setzero_si256(); - vec_lut[12] = _mm256_setzero_si256(); - 
vec_lut[11] = _mm256_setzero_si256(); - vec_lut[10] = _mm256_setzero_si256(); - vec_lut[9] = _mm256_setzero_si256(); - vec_lut[8] = vec_b0; - vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1); - vec_lut[7] = vec_b0; - vec_lut[6] = vec_b0; - vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1); - vec_lut[5] = vec_b1; - vec_lut[4] = _mm256_setzero_si256(); - vec_lut[3] = _mm256_setzero_si256(); - vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1); - vec_lut[2] = _mm256_setzero_si256(); - vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0); - vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1); - vec_lut[1] = _mm256_setzero_si256(); - vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0); - vec_lut[0] = _mm256_setzero_si256(); - vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0); - vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1); - - __m256i ix[16]; -#pragma unroll - for (int g = 0; g < 16; ++g) { - ix[g] = vec_lut[g]; - } - - Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); - Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); - -#pragma unroll - for (int g = 0; g < 8; ++g) { - ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); - ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); - ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); - } - - int8_t* qlut_i8 = reinterpret_cast(qlut); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); - - } - *lut_scales = scales; -#endif - return 0; -} -static bool is_type_supported(enum ggml_type type) { - if (type == GGML_TYPE_Q4_0 || - type == GGML_TYPE_TL2) { - return true; - } else { - return false; - } -} diff --git a/utils/templates/tl2_gen_transform.h b/utils/templates/tl2_gen_transform.h deleted file mode 100644 index d7159620..00000000 --- a/utils/templates/tl2_gen_transform.h +++ /dev/null @@ -1,37 +0,0 @@ -void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { - if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { - return; - } - - int k = tensor->ne[0]; - int m = tensor->ne[1]; - const int lut_scales_size = 1; - int bk = 0; - int bm = 0; - {% for kernel_shape in kernel_shapes %} - {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) { - bm = BM{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}; - bk = BBK{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}; - } - {% endfor %} - const int n_tile_num = m / bm; - const int BK = bk; - uint8_t * qweights; - bitnet_float_type * scales; - - scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); - qweights = (uint8_t *) tensor->data; - int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 
4 / 8; - if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes; - float * i2_scales = (float * )(qweights + nbytes); - scales[0] = (bitnet_float_type) i2_scales[0]; - - tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; - bitnet_tensor_extras[bitnet_tensor_extras_index++] = { - /* .lut_scales_size = */ lut_scales_size, - /* .BK = */ BK, - /* .n_tile_num = */ n_tile_num, - /* .qweights = */ qweights, - /* .scales = */ scales - }; -} diff --git a/utils/templates/tl2_table_impl.h b/utils/templates/tl2_table_impl.h deleted file mode 100644 index 3782bc90..00000000 --- a/utils/templates/tl2_table_impl.h +++ /dev/null @@ -1,250 +0,0 @@ -{% for kernel_shape in kernel_shapes %} -{% set pre = kernel_shape[0] ~ "_" ~ kernel_shape[1] %} -{% set BM = BM_list[loop.index0] %} -{% set BK = BK_list[loop.index0] %} -{% set bm = bm_list[loop.index0] %} -{% set k_list_indexed = k_list[loop.index0] %} - -static constexpr auto BM{{ pre }} = {{ BM }}; -static constexpr auto BBK{{ pre }} = {{ BK }}; - -template -inline void three_tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { -#ifdef __AVX2__ - const __m256i vec_mask = _mm256_set1_epi8(0x0f); - const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); - const __m256i vec_zero = _mm256_set1_epi8(0x00); - const __m256i vec_one = _mm256_set1_epi8(0xff); - const int KK = BBK{{ pre }} / 3; -#pragma unroll - for (int i = 0; i < BM{{ pre }}; i += 32) { - __m256i vec_as[KK / 2]; - __m256i vec_signs[KK / 8]; - #pragma unroll - for (int ai = 0; ai < KK / 2; ai++) { - vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); - } - #pragma unroll - for (int as = 0; as < KK / 8; as++) { - vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); - } -#pragma unroll - for (int bs = 0; bs < batch_size; bs++) { - __m256i vec_c0 = _mm256_setzero_si256(); - __m256i vec_c1 = _mm256_setzero_si256(); -#pragma unroll - for (int k = 0; k < KK / 8; k++) { - __m256i vec_sign = vec_signs[k]; - __m256i vec_a_0 = vec_as[k * 4 + 0]; - __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); - __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); - __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); - __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); - __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); - __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); - __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); - __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); - __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); - __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); - __m256i vec_v_top_lo_0 = 
_mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); - __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); - __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); - __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); - __m256i vec_a_1 = vec_as[k * 4 + 1]; - __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); - __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); - __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); - __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); - __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); - __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); - __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); - __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); - __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); - __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); - __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); - __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); - __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); - __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); - __m256i vec_a_2 = vec_as[k * 4 + 2]; - __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_2 = 
_mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); - __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); - __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); - __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); - __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); - __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); - __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); - __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); - __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); - __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); - __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); - __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); - __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); - __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); - __m256i vec_a_3 = vec_as[k * 4 + 3]; - __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); - __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); - __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); - __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); - __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); - __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); - __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); - __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); - __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); - __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); - __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); - __m256i vec_v_top_hi_3 = 
_mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); - __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); - __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); - } - __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs)); - __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs)); - __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs)); - __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs)); - vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); - vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); - vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); - vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3); - } - } -#endif -} - -template -inline int32_t two_tbl_impl{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a) { -#ifdef __AVX2__ - const __m256i vec_mask = _mm256_set1_epi8(0x0f); - const int KK = BK2 / 2; -#pragma unroll - for (int i = 0; i < BM{{ pre }}; i += 32) { - __m256i vec_as[KK / 2]; - #pragma unroll - for (int ai = 0; ai < KK / 2; ai++) { - vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); - } -#pragma unroll - for (int bs = 0; bs < batch_size; bs++) { - __m256i vec_c0 = _mm256_setzero_si256(); - __m256i vec_c1 = _mm256_setzero_si256(); -#pragma unroll - for (int k = 0; k < KK / 8; k++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - __m256i vec_a = vec_as[k * 4 + j]; - - __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); - __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); - __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); - __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); - - __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); - __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); - __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); - - __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); - __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); - __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); - - __m256i vec_v_top_lo = 
_mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); - __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); - __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); - __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); - } - } - - __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs)); - __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs)); - __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs)); - __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs)); - - vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); - vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); - vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); - vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3); - } - } -#endif - return 0; -} - -template -int32_t three_qgemm_lut_{{ pre }}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { - alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; - memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); -#pragma unroll - for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[1] }} / BBK{{ pre }}; ++k_outer) { - three_tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 3 / 2 * BM{{ pre }})])), (&(((uint8_t*)sign)[(k_outer * BBK{{ pre }} / 3 / 8 * BM{{ pre }})]))); - } -#pragma unroll - for (int bs = 0; bs < BATCH_SIZE; bs++) { -#pragma unroll - for (int i = 0; i < BM{{ pre }}; i++) { - ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]); - } - } - return 0; -} - -template -int32_t two_qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { - alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}]; - memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t)); -#pragma unroll - for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[0] }} / 32; ++k_outer) { - two_tbl_impl{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{{ pre }})]))); - } -#pragma unroll - for (int bs = 0; bs < BATCH_SIZE; bs++) { -#pragma unroll - for (int i = 0; i < BM{{ pre }}; i++) { - ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]); - ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; - } - } - return 0; -} -{% endfor %} diff --git a/utils/templates/tl2_top_api.h b/utils/templates/tl2_top_api.h deleted file mode 100644 index 681546bd..00000000 --- a/utils/templates/tl2_top_api.h +++ /dev/null @@ -1,49 +0,0 @@ -void ggml_preprocessor(int bs, int m, int three_k, int 
two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) { - partial_max_reset(bs, (&(((float*)LUT_Scales)[0]))); -{% for kernel_shape in kernel_shapes %} - {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && two_k == {{ k_list[loop.index0][0] }} && three_k == {{ k_list[loop.index0][1] }}) { - for (int32_t b = 0; b < bs; b++) { - per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); - three_lut_ctor<{{ k_list[loop.index0][1] }}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); - two_lut_ctor<{{ k_list[loop.index0][0] }}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {{ k_list[loop.index0][1] }}])), (&(((float*)LUT_Scales)[b]))); - } - } -{% endfor %} -} - -void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { -{% for kernel_shape in kernel_shapes %} - {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) { - if (BK == {{ k_list[loop.index0][0] }}) { - if (bs == 1) { - two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, LUT, Scales, LUT_Scales, C); - } else if (bs == 8) { - two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, LUT, Scales, LUT_Scales, C); - } else if (bs == 32) { - two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, LUT, Scales, LUT_Scales, C); - } else if (bs == 128) { - two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, LUT, Scales, LUT_Scales, C); - } else if (bs == 256) { - two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, LUT, Scales, LUT_Scales, C); - } else if (bs == 512) { - two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, LUT, Scales, LUT_Scales, C); - } - } - else if (BK == {{ k_list[loop.index0][1] }}) { - if (bs == 1) { - three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, sign, LUT, Scales, LUT_Scales, C); - }else if (bs == 8) { - three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, sign, LUT, Scales, LUT_Scales, C); - }else if (bs == 32) { - three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, sign, LUT, Scales, LUT_Scales, C); - }else if (bs == 128) { - three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, sign, LUT, Scales, LUT_Scales, C); - }else if (bs == 256) { - three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, sign, LUT, Scales, LUT_Scales, C); - }else if (bs == 512) { - three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, sign, LUT, Scales, LUT_Scales, C); - } - } - } -{% endfor %} -} From c8d72dd82c964a8408fe61a6b351cc7d89ca0848 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Sat, 26 Oct 2024 00:19:09 +0200 Subject: [PATCH 14/19] Use constexpr instead of a define --- utils/templates/tl2.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/templates/tl2.h b/utils/templates/tl2.h index e75c8457..d770a6de 100644 --- a/utils/templates/tl2.h +++ b/utils/templates/tl2.h @@ -2,7 
+2,7 @@
 #include "ggml-bitnet.h"
 #include 
 #include 
-#define GGML_BITNET_MAX_NODES 8192
+static constexpr auto GGML_BITNET_MAX_NODES = 8192;
 static bool initialized = false;
 static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
 static size_t bitnet_tensor_extras_index = 0;

From 5b672a0725ff48711b70c01a2c89b005f22f6718 Mon Sep 17 00:00:00 2001
From: Goran Jelic-Cizmek
Date: Sat, 26 Oct 2024 00:47:11 +0200
Subject: [PATCH 15/19] Set partial loop unrolling for GCC

Also fix indentation
---
 utils/templates/tl2.h | 359 ++++++++++++++++++++++--------------------
 1 file changed, 188 insertions(+), 171 deletions(-)

diff --git a/utils/templates/tl2.h b/utils/templates/tl2.h
index d770a6de..00fbd98b 100644
--- a/utils/templates/tl2.h
+++ b/utils/templates/tl2.h
@@ -2,6 +2,23 @@
 #include "ggml-bitnet.h"
 #include 
 #include 
+
+/* Below macros are used for multi-compiler support for unrolling loops */
+#define TO_STRING_HELPER(X) #X
+#define TO_STRING(X) TO_STRING_HELPER(X)
+
+#if defined(__clang__)
+    #define UNROLL_LOOP(n) _Pragma(TO_STRING(unroll))
+#elif defined(__GNUC__) && !defined(__clang__)
+    #define UNROLL_LOOP(n) _Pragma(TO_STRING(GCC unroll (n)))
+#elif defined(_MSC_BUILD)
+    #pragma message ("Microsoft Visual C++ (MSVC) detected: Loop unrolling not supported!")
+    #define UNROLL_LOOP(n)
+#else
+    #warning "Unknown compiler: Loop unrolling not supported!"
+    #define UNROLL_LOOP(n)
+#endif
+
 static constexpr auto GGML_BITNET_MAX_NODES = 8192;
 static bool initialized = false;
 static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
@@ -109,7 +126,7 @@ inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_t
         0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,
         0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00
     );
-#pragma unroll
+    #pragma unroll
     for (int k = 0; k < act_k / 24; ++k) {
         __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1);
         __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1);
@@ -151,7 +168,7 @@ inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_t
         vec_lut[0] = _mm256_setzero_si256();
         __m256i ix[16];

-#pragma unroll
+        UNROLL_LOOP(16)
         for (int g = 0; g < 16; ++g) {
             ix[g] = vec_lut[g];
         }
@@ -159,7 +176,7 @@ inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_t
         Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));
         Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));

-#pragma unroll
+        UNROLL_LOOP(8)
         for (int g = 0; g < 8; ++g) {
             ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);
             ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));
@@ -195,7 +212,7 @@ inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_typ
         0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,
         0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00
     );
-#pragma unroll
+    #pragma unroll
     for (int k = 0; k < act_k / 16; ++k) {
         __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1);
         __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1);
@@ -228,7 +245,7 @@ inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_typ
         vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1);

         __m256i ix[16];
-#pragma unroll
+        UNROLL_LOOP(16)
         for (int g = 0; g < 16; ++g) {
             ix[g] = vec_lut[g];
         }
@@ -236,7 +253,7 @@ inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_typ
         Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));
         Transpose_8_8(&(ix[8]), &(ix[9]), 
&(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); -#pragma unroll + UNROLL_LOOP(8) for (int g = 0; g < 8; ++g) { ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); @@ -286,8 +303,8 @@ inline void three_tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a, uint8_ const __m256i vec_zero = _mm256_set1_epi8(0x00); const __m256i vec_one = _mm256_set1_epi8(0xff); const int KK = BBK{{ pre }} / 3; -#pragma unroll - for (int i = 0; i < BM{{ pre }}; i += 32) { + UNROLL_LOOP(BM{{ pre }}) + for (int i = 0; i < BM{{ pre }}; i += 32) { __m256i vec_as[KK / 2]; __m256i vec_signs[KK / 8]; #pragma unroll @@ -298,119 +315,119 @@ inline void three_tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a, uint8_ for (int as = 0; as < KK / 8; as++) { vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); } -#pragma unroll - for (int bs = 0; bs < batch_size; bs++) { - __m256i vec_c0 = _mm256_setzero_si256(); - __m256i vec_c1 = _mm256_setzero_si256(); -#pragma unroll - for (int k = 0; k < KK / 8; k++) { - __m256i vec_sign = vec_signs[k]; - __m256i vec_a_0 = vec_as[k * 4 + 0]; - __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); - __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); - __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); - __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); - __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); - __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); - __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); - __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); - __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); - __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); - __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); - __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); - __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); - __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); - __m256i vec_a_1 = vec_as[k * 4 + 1]; - __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 
64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); - __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); - __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); - __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); - __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); - __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); - __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); - __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); - __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); - __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); - __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); - __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); - __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); - __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); - __m256i vec_a_2 = vec_as[k * 4 + 2]; - __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); - __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); - __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); - __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); - __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); - __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); - __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); - __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); - __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); - __m256i vec_v_bot_sec_2 = 
_mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); - __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); - __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); - __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); - __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); - __m256i vec_a_3 = vec_as[k * 4 + 3]; - __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); - __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); - __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); - __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); - __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); - __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); - __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); - __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); - __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); - __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); - __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); - __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); - __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); - __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); - __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); - __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); - __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); - __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); - vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); - vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + #pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); + #pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k 
* 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = 
_mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 
32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));
+            __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));
+            __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);
+            __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);
+            __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);
+            __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);
+            __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);
+            __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);
+            __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);
+            __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);
+            __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);
+            __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);
+            __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);
+            __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);
+            __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);
+            __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);
+            vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);
+            vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);
+            vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);
+            vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);
+        }
+        __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs));
+        __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs));
+        __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs));
+        __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs));
+        vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+        vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+        vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+        vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0);
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1);
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2);
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3);
        }
-    __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs));
-    __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs));
-    __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs));
-    __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs));
-    vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
-    vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
-    vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
-    vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3);
-    }
     }
 }
 #endif
 }
@@ -420,62 +437,62 @@ inline int32_t two_tbl_impl{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a) {
 #ifdef __AVX2__
     const __m256i vec_mask = _mm256_set1_epi8(0x0f);
     const int KK = BK2 / 2;
-#pragma unroll
+    UNROLL_LOOP(BM{{ pre }})
     for (int i = 0; i < BM{{ pre }}; i += 32) {
         __m256i vec_as[KK / 2];
 #pragma unroll
         for (int ai = 0; ai < KK / 2; ai++) {
             vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
         }
-#pragma unroll
-    for (int bs = 0; bs < batch_size; bs++) {
-        __m256i vec_c0 = _mm256_setzero_si256();
-        __m256i vec_c1 = _mm256_setzero_si256();
-#pragma unroll
-        for (int k = 0; k < KK / 8; k++) {
+        #pragma unroll
+        for (int bs = 0; bs < batch_size; bs++) {
+            __m256i vec_c0 = _mm256_setzero_si256();
+            __m256i vec_c1 = _mm256_setzero_si256();
 #pragma unroll
-            for (int j = 0; j < 4; j++) {
-                __m256i vec_a = vec_as[k * 4 + j];
-
-                __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));
-                __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));
-                __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));
-                __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));
-
-                __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);
-                __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);
-                __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);
-
-                __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);
-                __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);
-                __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);
-
-                __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);
-                __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);
-                __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);
-                __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);
-                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);
-                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);
-                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);
-                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo);
+            for (int k = 0; k < KK / 8; k++) {
+                UNROLL_LOOP(4)
+                for (int j = 0; j < 4; j++) {
+                    __m256i vec_a = vec_as[k * 4 + j];
+
+                    __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));
+                    __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));
+                    __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));
+                    __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));
+
+                    __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);
+                    __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);
+                    __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);
+
+                    __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);
+                    __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);
+                    __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);
+
+                    __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);
+                    __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);
+                    __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);
+                    __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);
+                    vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);
+                    vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);
+                    vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);
+                    vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo);
+                }
             }
-        }
-        __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs));
-        __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs));
-        __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs));
-        __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs));
+            __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs));
+            __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs));
+            __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs));
+            __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs));
 
-        vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
-        vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
-        vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
-        vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+            vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+            vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+            vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+            vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
 
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0);
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1);
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2);
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3);
-    }
+            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{{ pre }} * bs), vec_gc0);
+            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{{ pre }} * bs), vec_gc1);
+            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{{ pre }} * bs), vec_gc2);
+            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{{ pre }} * bs), vec_gc3);
+        }
     }
 #endif
     return 0;
 }
@@ -485,13 +502,13 @@ template <int BATCH_SIZE>
 int32_t three_qgemm_lut_{{ pre }}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {
     alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}];
     memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t));
-#pragma unroll
+    UNROLL_LOOP({{ k_list_indexed[1] }})
     for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[1] }} / BBK{{ pre }}; ++k_outer) {
         three_tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 3 / 2 * BM{{ pre }})])), (&(((uint8_t*)sign)[(k_outer * BBK{{ pre }} / 3 / 8 * BM{{ pre }})])));
     }
-#pragma unroll
+    #pragma unroll
     for (int bs = 0; bs < BATCH_SIZE; bs++) {
-#pragma unroll
+        UNROLL_LOOP(BM{{ pre }})
         for (int i = 0; i < BM{{ pre }}; i++) {
             ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]);
         }
@@ -503,13 +520,13 @@ template <int BATCH_SIZE>
 int32_t two_qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
     alignas(32) uint32_t CBits[BATCH_SIZE * BM{{ pre }}];
     memset(&(CBits[0]), 0, BATCH_SIZE * BM{{ pre }} * sizeof(int32_t));
-#pragma unroll
+    UNROLL_LOOP({{ k_list_indexed[0] }})
    for (int32_t k_outer = 0; k_outer < {{ k_list_indexed[0] }} / 32; ++k_outer) {
         two_tbl_impl{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{{ pre }})])));
     }
-#pragma unroll
+    #pragma unroll
     for (int bs = 0; bs < BATCH_SIZE; bs++) {
-#pragma unroll
+        UNROLL_LOOP(BM{{ pre }})
         for (int i = 0; i < BM{{ pre }}; i++) {
             ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{{ pre }}]);
             ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];

From d466e944b6bc2d546ba42954360e61ba3d80e78d Mon Sep 17 00:00:00 2001
From: Goran Jelic-Cizmek
Date: Sat, 26 Oct 2024 01:34:14 +0200
Subject: [PATCH 16/19] Reduce repetition

---
 utils/templates/tl2.h | 36 ++++++++++++------------------------
 1 file changed, 12 insertions(+), 24 deletions(-)

diff --git a/utils/templates/tl2.h b/utils/templates/tl2.h
index 00fbd98b..174a60c4 100644
--- a/utils/templates/tl2.h
+++ b/utils/templates/tl2.h
@@ -552,34 +552,22 @@ void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT
 {% for kernel_shape in kernel_shapes %}
 {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) {
     if (BK == {{ k_list[loop.index0][0] }}) {
-        if (bs == 1) {
-            two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, LUT, Scales, LUT_Scales, C);
-        } else if (bs == 8) {
-            two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, LUT, Scales, LUT_Scales, C);
-        } else if (bs == 32) {
-            two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, LUT, Scales, LUT_Scales, C);
-        } else if (bs == 128) {
-            two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, LUT, Scales, LUT_Scales, C);
-        } else if (bs == 256) {
-            two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, LUT, Scales, LUT_Scales, C);
-        } else if (bs == 512) {
-            two_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, LUT, Scales, LUT_Scales, C);
+        {% set block_sizes = [1, 8, 32, 128, 256, 512] %}
+        {% set outer_loop = loop %}
+        {% for bs in block_sizes %}
+        {% if loop.index0 > 0 %}else {% endif %}if (bs == {{ bs }}) {
+            two_qgemm_lut_{{ kernel_shapes[outer_loop.index0][0] }}_{{ kernel_shapes[outer_loop.index0][1] }}<{{ bs }}>(A, LUT, Scales, LUT_Scales, C);
         }
+        {% endfor %}
     } else if (BK == {{ k_list[loop.index0][1] }}) {
-        if (bs == 1) {
-            three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<1>(A, sign, LUT, Scales, LUT_Scales, C);
-        }else if (bs == 8) {
-            three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<8>(A, sign, LUT, Scales, LUT_Scales, C);
-        }else if (bs == 32) {
-            three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<32>(A, sign, LUT, Scales, LUT_Scales, C);
-        }else if (bs == 128) {
-            three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<128>(A, sign, LUT, Scales, LUT_Scales, C);
-        }else if (bs == 256) {
-            three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<256>(A, sign, LUT, Scales, LUT_Scales, C);
-        }else if (bs == 512) {
-            three_qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}<512>(A, sign, LUT, Scales, LUT_Scales, C);
+        {% set block_sizes = [1, 8, 32, 128, 256, 512] %}
+        {% set outer_loop = loop %}
+        {% for bs in block_sizes %}
+        {% if loop.index0 > 0 %}else {% endif %}if (bs == {{ bs }}) {
+            three_qgemm_lut_{{ kernel_shapes[outer_loop.index0][0] }}_{{ kernel_shapes[outer_loop.index0][1] }}<{{ bs }}>(A, sign, LUT, Scales, LUT_Scales, C);
         }
+        {% endfor %}
     }
 }
 {% endfor %}

From 3034848e208f6ea1172c22326d351b46c847db87 Mon Sep 17 00:00:00 2001
From: Goran Jelic-Cizmek
Date: Sat, 26 Oct 2024 01:40:52 +0200
Subject: [PATCH 17/19] Reduce repetition

---
 utils/templates/tl2.h | 24 ++++++------------------
 1 file changed, 6 insertions(+), 18 deletions(-)

diff --git a/utils/templates/tl2.h b/utils/templates/tl2.h
index 174a60c4..549904b6 100644
--- a/utils/templates/tl2.h
+++ b/utils/templates/tl2.h
@@ -184,15 +184,9 @@ inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_t
         ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));
     }
     int8_t* qlut_i8 = reinterpret_cast<int8_t*>(qlut);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);
-
+    {%- for i in range(8) %}
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + {{ i }} * 32 + 0), ix[{{ i }}]);
+    {%- endfor %}
     }
 
     *lut_scales = scales;
@@ -263,15 +257,9 @@ inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_typ
 
     int8_t* qlut_i8 = reinterpret_cast<int8_t*>(qlut);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);
-
+    {%- for i in range(8) %}
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + {{ i }} * 32 + 0), ix[{{ i }}]);
+    {%- endfor %}
     }
     *lut_scales = scales;
 #endif

From 47c1dd0d7e03fad6c68c91f50c6ecf8adeab01bb Mon Sep 17 00:00:00 2001
From: Goran Jelic-Cizmek
Date: Sat, 26 Oct 2024 01:45:15 +0200
Subject: [PATCH 18/19] Revert "Add GCC to compiler check"

This reverts commit 9d37b8692d4a587aa699f80805a9579777e18c8c since it's part of another PR.
---
 src/CMakeLists.txt | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index bac84596..9cead70f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -4,7 +4,7 @@ set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp)
 
 include_directories(3rdparty/llama.cpp/ggml/include)
 
-if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR
-    NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
-    message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation")
-endif()
+if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR
+(NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
+    message(FATAL_ERROR "Clang is required for Bitnet.cpp compilation")
+endif()
\ No newline at end of file

From bcd594a9fa0193f348fc265b7e7fbbd226af5806 Mon Sep 17 00:00:00 2001
From: Goran Jelic-Cizmek
Date: Sat, 26 Oct 2024 01:46:05 +0200
Subject: [PATCH 19/19] Put back change from main

It's part of another PR
---
 setup_env.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup_env.py b/setup_env.py
index 8a9c4b46..b9bf5fc5 100644
--- a/setup_env.py
+++ b/setup_env.py
@@ -34,6 +34,7 @@
 
 OS_EXTRA_ARGS = {
     "Windows":["-T", "ClangCL"],
+    "Linux": ["-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"]
 }
 
 ARCH_ALIAS = {