diff --git a/requirements.txt b/requirements.txt
index 3f5c5472..bdff4394 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@
 -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt
 -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt
 -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt
--r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt
\ No newline at end of file
+-r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt
+jinja2
diff --git a/utils/codegen_tl1.py b/utils/codegen_tl1.py
index 4c2e7dd3..7e02b144 100644
--- a/utils/codegen_tl1.py
+++ b/utils/codegen_tl1.py
@@ -1,371 +1,8 @@
 import argparse
 import os
 from configparser import ConfigParser
-
-def gen_ctor_code():
-    kernel_code = "\n\
-#include \"ggml-bitnet.h\"\n\
-#define GGML_BITNET_MAX_NODES 8192\n\
-static bool initialized = false;\n\
-static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
-static size_t bitnet_tensor_extras_index = 0;\n\
-static void * aligned_malloc(size_t size) {{\n\
-#if defined(_WIN32)\n\
-    return _aligned_malloc(size, 64);\n\
-#else\n\
-    void * ptr = nullptr;\n\
-    posix_memalign(&ptr, 64, size);\n\
-    return ptr;\n\
-#endif\n\
-}}\n\
-static void aligned_free(void * ptr) {{\n\
-#if defined(_WIN32)\n\
-    _aligned_free(ptr);\n\
-#else\n\
-    free(ptr);\n\
-#endif\n\
-}}\n\
-\n\
-void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\
-    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
-    bitnet_float_type* b = (bitnet_float_type*)b_;\n\
-#ifdef __ARM_NEON\n\
-    float32x4_t temp_max = vdupq_n_f32(0);\n\
-    for (int i=0; i < k / 4; i++) {{\n\
-        float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\
-        float32x4_t abssum = vabsq_f32(vec_bs);\n\
-        temp_max = vmaxq_f32(abssum, temp_max);\n\
-    }}\n\
-    float32_t scales = 127 / vmaxvq_f32(temp_max);\n\
-    *lut_scales = scales;\n\
-#elif defined __AVX2__\n\
-    __m256 max_vec = _mm256_set1_ps(0.f);\n\
-    const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
-    // #pragma unroll\n\
-    for (int i = 0; i < k / 8; i++) {{\n\
-        __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\
-        __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\
-        max_vec = _mm256_max_ps(vec_babs, max_vec);\n\
-    }}\n\
-    __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\
-    max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\
-    max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\
-    float scales = 127 / _mm_cvtss_f32(max1);\n\
-    *lut_scales = scales;\n\
-#endif\n\
-}}\n\
-\n\
-void partial_max_reset(void* lut_scales_) {{\n\
-    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
-    *lut_scales = 0.0;\n\
-}}\n\
-\n\
-#ifdef __ARM_NEON\n\
-inline void Transpose_8_8(\n\
-    int16x8_t *v0,\n\
-    int16x8_t *v1,\n\
-    int16x8_t *v2,\n\
-    int16x8_t *v3,\n\
-    int16x8_t *v4,\n\
-    int16x8_t *v5,\n\
-    int16x8_t *v6,\n\
-    int16x8_t *v7)\n\
-{{\n\
-    int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\
-    int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\
-    int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\
-    int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\
-\n\
-    int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\
-    int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\
-    int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\
-    int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\
-\n\
-    int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\
-    int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\
-    int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\
-    int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\
-\n\
-    *v0 = q_fin_0.val[0];\n\
-    *v1 = q_fin_0.val[1];\n\
-    *v2 = q_fin_1.val[0];\n\
-    *v3 = q_fin_1.val[1];\n\
-    *v4 = q_fin_2.val[0];\n\
-    *v5 = q_fin_2.val[1];\n\
-    *v6 = q_fin_3.val[0];\n\
-    *v7 = q_fin_3.val[1];\n\
-}}\n\
-#endif\n\
-\n\
-template<int act_k>\n\
-inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\
-#ifdef __ARM_NEON\n\
-    int16x8_t vec_lut[16];\n\
-    float32_t scales = *lut_scales;\n\
-    uint8_t tbl_mask[16];\n\
-    tbl_mask[0] = 0;\n\
-    tbl_mask[1] = 2;\n\
-    tbl_mask[2] = 4;\n\
-    tbl_mask[3] = 6;\n\
-    tbl_mask[4] = 8;\n\
-    tbl_mask[5] = 10;\n\
-    tbl_mask[6] = 12;\n\
-    tbl_mask[7] = 14;\n\
-    tbl_mask[8] = 1;\n\
-    tbl_mask[9] = 3;\n\
-    tbl_mask[10] = 5;\n\
-    tbl_mask[11] = 7;\n\
-    tbl_mask[12] = 9;\n\
-    tbl_mask[13] = 11;\n\
-    tbl_mask[14] = 13;\n\
-    tbl_mask[15] = 15;\n\
-    uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\
-#pragma unroll\n\
-    for (int k = 0; k < act_k / 16; ++k) {{\n\
-        float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\
-        float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\
-        float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\
-        float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\
-        float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\
-        float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\
-        int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\
-        int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\
-        int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\
-        int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\
-        int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\
-        int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\
-        int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\
-        int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\
-        int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\
-        int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\
-        vec_lut[0] = vdupq_n_s16(0);\n\
-        vec_lut[0] = vec_lut[0] - vec_bs_0;\n\
-        vec_lut[0] = vec_lut[0] - vec_bs_1;\n\
-        vec_lut[1] = vdupq_n_s16(0);\n\
-        vec_lut[1] = vec_lut[1] - vec_bs_0;\n\
-        vec_lut[2] = vdupq_n_s16(0);\n\
-        vec_lut[2] = vec_lut[2] - vec_bs_0;\n\
-        vec_lut[2] = vec_lut[2] + vec_bs_1;\n\
-        vec_lut[3] = vdupq_n_s16(0);\n\
-        vec_lut[3] = vec_lut[3] - vec_bs_1;\n\
-        vec_lut[4] = vdupq_n_s16(0);\n\
-        vec_lut[5] = vec_bs_1;\n\
-        vec_lut[6] = vec_bs_0;\n\
-        vec_lut[6] = vec_lut[6] - vec_bs_1;\n\
-        vec_lut[7] = vec_bs_0;\n\
-        vec_lut[8] = vec_bs_0;\n\
-        vec_lut[8] = vec_lut[8] + vec_bs_1;\n\
-        Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\
-                      &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\
-        Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\
-                      &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\
-#pragma unroll\n\
-        for (int idx = 0; idx < 8; idx++) {{\n\
-            int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\
-            int8x8_t q0_low = vget_low_s8(q0_s);\n\
-            int8x8_t q0_high = vget_high_s8(q0_s);\n\
-            int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\
-            int8x8_t q1_low = vget_low_s8(q1_s);\n\
-            int8x8_t q1_high = vget_high_s8(q1_s);\n\
-            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\
-            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\
-            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\
-            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\
-        }}\n\
-    }}\n\
-#endif\n\
-}}\n\
-\n\
-static bool is_type_supported(enum ggml_type type) {{\n\
-    if (type == GGML_TYPE_Q4_0 ||\n\
-        type == GGML_TYPE_TL1) {{\n\
-        return true;\n\
-    }} else {{\n\
-        return false;\n\
-    }}\n\
-}}\n\
-"
-    return kernel_code
-
-def gen_body_core_code(bm, by):
-    length = 4
-    all_code = ""
-    for i in range(length):
-        core_code = "\n\
-        uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\
-        uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\
-        uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\
-        int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\
-        int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\
-        int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\
-        int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\
-        int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\
-        int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\
-        vec_c[{6}] += vec_v_left_{0}.val[0];\n\
-        vec_c[{6}] += vec_v_right_{0}.val[0];\n\
-        vec_c[{7}] += vec_v_left_{0}.val[1];\n\
-        vec_c[{7}] += vec_v_right_{0}.val[1];\n\
-        ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1)
-
-        all_code = "".join([all_code, core_code])
-
-    all_code = "".join([all_code, "\n    }\n\n"])
-
-    for i in range(bm // 8):
-        core_code = "\
-        int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\
-        int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\
-        vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\
-        vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4)
-        all_code = "".join([all_code, core_code])
-
-    return all_code
-
-def gen_tbl_impl(pre, BM, BK, bm, k):
-
-    kernel_code = "\
-#include <arm_neon.h>\n\
-\n\
-#define BM{0} {1}\n\
-#define BBK{0} {2}\n\
-inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
-#ifdef __ARM_NEON\n\
-    const int KK = BBK{0} / 2;\n\
-    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
-    const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\
-    int8x16_t vec_lut[2 * KK];\n\
-".format(pre, BM, BK)
-
-    kernel_code = "".join([kernel_code, "    int16x8_t vec_c[{}];".format(bm // 8)])
-
-    kernel_code = "".join([kernel_code, "\n\
-#pragma unroll\n\
-    for (int k = 0; k < 2 * KK; k++) {\n\
-        vec_lut[k] = vld1q_s8(lut + k * 16);\n\
-    }\n"])
-
-    pre_core_code = "\n\
-#pragma unroll\n\
-    for (int i = 0; i < BM{}; i += {}) {{\n\
-        #pragma unroll\n\
-        for (int i=0; i<{}; i++) {{\n\
-            vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\
-        }}\n".format(pre, bm, bm // 8)
-
-    body_core_pre_code = "\n\
-#pragma unroll\n\
-    for (int k = 0; k < KK / {}; k++) {{\n\
-        ".format(256 // bm // 2)
-
-    body_core_post_code = "\n\
-    }\n\
-\
-#endif\n\
-}\n"
-
-    kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code])
-
-    kernel_code = "".join([kernel_code, "\n\
-int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
-    alignas({1}) uint32_t CBits[BM{0}];\n\
-    memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\
-#pragma unroll\n\
-    for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\
-        tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\
-    }}\n\
-#pragma unroll\n\
-    for (int i = 0; i < BM{0}; i++) {{\n\
-        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\
-    }}\n\
-    return 0;\n\
-}};\n".format(pre, min(32, BK), k)])
-
-    return kernel_code
-
-def gen_top_api(kernel_shapes):
-
-    kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\
-    if (m == {0} && k == {1}) {{\n\
-        preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
-    }}\n\
-".format(kernel_shapes[0][0], kernel_shapes[0][1])
-    for i in range(1, len(kernel_shapes)):
-        kernel_code = "".join([kernel_code, "    else if (m == {0} && k == {1}) {{\n\
-        preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
-    }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
-    kernel_code = "".join([kernel_code, "}\n"])
-    kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
-    if (m == {0} && k == {1}) {{\n\
-        qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
-    }}\n\
-".format(kernel_shapes[0][0], kernel_shapes[0][1])])
-    for i in range(1, len(kernel_shapes)):
-        kernel_code = "".join([kernel_code, "    else if (m == {0} && k == {1}) {{\n\
-        qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
-    }}\n\
-".format(kernel_shapes[i][0], kernel_shapes[i][1])])
-    kernel_code = "".join([kernel_code, "}\n"])
-    return kernel_code
-
-def gen_preprocess_code():
-    kernel_code = "\n\
-template<int K>\n\
-void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\
-    partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\
-    per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\
-    \n\
-    lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\
-}}\n"
-    return kernel_code
-
-def gen_transform_code(kernel_shape):
-    kernel_code = "\n\
-void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
-    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
-        return;\n\
-    }\n\
-\n\
-    int k = tensor->ne[0];\n\
-    int m = tensor->ne[1];\n\
-    const int lut_scales_size = 1;\n\
-    const int scales_size = 1;\n\
-    int bk = 0;\n\
-    int bm = 0;\n"
-
-    kernel_code = "".join([kernel_code, "\n\
-    if (m == {0} && k == {1}) {{\n\
-        bm = BM{0}_{1};\n\
-        bk = BBK{0}_{1};\n\
-    }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])
-
-    for i in range(1, len(kernel_shapes)):
-        kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
-        bm = BM{0}_{1};\n\
-        bk = BBK{0}_{1};\n\
-    }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
-
-    kernel_code = "".join([kernel_code, "\n\
-    const int n_tile_num = m / bm;\n\
-    const int BK = bk;\n\
-    uint8_t * qweights;\n\
-    bitnet_float_type * scales;\n\
-\n\
-    scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
-    qweights = (uint8_t *) tensor->data;\n\
-    float * i2_scales = (float * )(qweights + k * m / 4);\n\
-    scales[0] = (bitnet_float_type) i2_scales[0];\n\
-\n\
-    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
-    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
-    /* .lut_scales_size = */ lut_scales_size,\n\
-    /* .BK = */ BK,\n\
-    /* .n_tile_num = */ n_tile_num,\n\
-    /* .qweights = */ qweights,\n\
-    /* .scales = */ scales\n\
-    };\n\
-}\n"])
-
-    return kernel_code
+from pathlib import Path
+from jinja2 import Environment, FileSystemLoader
if __name__ == "__main__": ModelShapeDict = { @@ -405,28 +42,23 @@ def gen_transform_code(kernel_shape): assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0" assert bm_list[i] in [32, 64], "choose bm from [32, 64]" - tbl_impl_code = [] - - for i in range(len(kernel_shapes)): - tbl_impl_code.append( - gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1]) - ) - api_code = gen_top_api(kernel_shapes) - pre_code = gen_preprocess_code() - ctor_code = gen_ctor_code() - trans_code = gen_transform_code(kernel_shapes) - - output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include") - - with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f: - f.write(''.join("#if defined(GGML_BITNET_ARM_TL1)")) - f.write(''.join(ctor_code)) - for code in tbl_impl_code: - f.write(''.join(code)) - f.write(''.join(pre_code)) - f.write(''.join(api_code)) - f.write(''.join(trans_code)) - f.write(''.join("#endif")) + env = Environment( + loader=FileSystemLoader(Path(__file__).parent / "templates"), + ) + template = env.get_template("tl1.h") + contents = template.render( + kernel_shapes=kernel_shapes, + k_list=[item[1] for item in kernel_shapes], + BM_list=BM_list, + BK_list=BK_list, + bm_list=bm_list, + min=min, + range=range, + ) + output_dir = Path(__file__).parent.parent / "include" + + with open(str(output_dir / "bitnet-lut-kernels.h"), 'w') as f: + f.write(contents) config = ConfigParser() @@ -438,5 +70,5 @@ def gen_transform_code(kernel_shape): config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i])) config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i])) - with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile: - config.write(configfile) \ No newline at end of file + with open(str(output_dir / "kernel_config.ini"), 'w') as configfile: + config.write(configfile) diff --git a/utils/templates/tl1.h b/utils/templates/tl1.h new file mode 100644 index 00000000..cb7c702b --- /dev/null +++ b/utils/templates/tl1.h @@ -0,0 +1,318 @@ +#if defined(GGML_BITNET_ARM_TL1) +#include +#include +#include +#include "ggml-bitnet.h" +static constexpr auto GGML_BITNET_MAX_NODES = 8192; +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +} +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +void per_tensor_quant(int k, void* lut_scales_, void* b_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#ifdef __ARM_NEON + float32x4_t temp_max = vdupq_n_f32(0); + for (int i=0; i < k / 4; i++) { + float32x4_t vec_bs = vld1q_f32(b + 4 * i); + float32x4_t abssum = vabsq_f32(vec_bs); + temp_max = vmaxq_f32(abssum, temp_max); + } + float32_t scales = 127 / vmaxvq_f32(temp_max); + *lut_scales = scales; +#elif defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + // #pragma unroll + for (int i = 0; i < k / 8; i++) { + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + } + __m128 max1 = 
+    max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
+    max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
+    float scales = 127 / _mm_cvtss_f32(max1);
+    *lut_scales = scales;
+#endif
+}
+
+void partial_max_reset(void* lut_scales_) {
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+    *lut_scales = 0.0;
+}
+
+#ifdef __ARM_NEON
+inline void Transpose_8_8(
+    int16x8_t *v0,
+    int16x8_t *v1,
+    int16x8_t *v2,
+    int16x8_t *v3,
+    int16x8_t *v4,
+    int16x8_t *v5,
+    int16x8_t *v6,
+    int16x8_t *v7)
+{
+    int16x8x2_t q04 = vzipq_s16(*v0, *v4);
+    int16x8x2_t q15 = vzipq_s16(*v1, *v5);
+    int16x8x2_t q26 = vzipq_s16(*v2, *v6);
+    int16x8x2_t q37 = vzipq_s16(*v3, *v7);
+
+    int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
+    int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
+    int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
+    int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
+
+    int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
+    int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
+    int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
+    int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
+
+    *v0 = q_fin_0.val[0];
+    *v1 = q_fin_0.val[1];
+    *v2 = q_fin_1.val[0];
+    *v3 = q_fin_1.val[1];
+    *v4 = q_fin_2.val[0];
+    *v5 = q_fin_2.val[1];
+    *v6 = q_fin_3.val[0];
+    *v7 = q_fin_3.val[1];
+}
+#endif
+
+template<int act_k>
+inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
+#ifdef __ARM_NEON
+    int16x8_t vec_lut[16];
+    float32_t scales = *lut_scales;
+    uint8_t tbl_mask[16];
+    tbl_mask[0] = 0;
+    tbl_mask[1] = 2;
+    tbl_mask[2] = 4;
+    tbl_mask[3] = 6;
+    tbl_mask[4] = 8;
+    tbl_mask[5] = 10;
+    tbl_mask[6] = 12;
+    tbl_mask[7] = 14;
+    tbl_mask[8] = 1;
+    tbl_mask[9] = 3;
+    tbl_mask[10] = 5;
+    tbl_mask[11] = 7;
+    tbl_mask[12] = 9;
+    tbl_mask[13] = 11;
+    tbl_mask[14] = 13;
+    tbl_mask[15] = 15;
+    uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
+    #pragma unroll
+    for (int k = 0; k < act_k / 16; ++k) {
+        float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
+        float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
+        float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
+        float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
+        float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
+        float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
+        int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
+        int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
+        int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
+        int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
+        int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
+        int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
+        int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
+        int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
+        int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
+        int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
+        vec_lut[0] = vdupq_n_s16(0);
+        vec_lut[0] = vec_lut[0] - vec_bs_0;
+        vec_lut[0] = vec_lut[0] - vec_bs_1;
+        vec_lut[1] = vdupq_n_s16(0);
+        vec_lut[1] = vec_lut[1] - vec_bs_0;
+        vec_lut[2] = vdupq_n_s16(0);
+        vec_lut[2] = vec_lut[2] - vec_bs_0;
+        vec_lut[2] = vec_lut[2] + vec_bs_1;
+        vec_lut[3] = vdupq_n_s16(0);
+        vec_lut[3] = vec_lut[3] - vec_bs_1;
+        vec_lut[4] = vdupq_n_s16(0);
+        vec_lut[5] = vec_bs_1;
+        vec_lut[6] = vec_bs_0;
+        vec_lut[6] = vec_lut[6] - vec_bs_1;
+        vec_lut[7] = vec_bs_0;
+        vec_lut[8] = vec_bs_0;
+        vec_lut[8] = vec_lut[8] + vec_bs_1;
+        Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
+                      &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
&(vec_lut[3]), + &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); + Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), + &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); + #pragma unroll + for (int idx = 0; idx < 8; idx++) { + int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); + int8x8_t q0_low = vget_low_s8(q0_s); + int8x8_t q0_high = vget_high_s8(q0_s); + int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); + int8x8_t q1_low = vget_low_s8(q1_s); + int8x8_t q1_high = vget_high_s8(q1_s); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); + } + } +#endif +} + +static bool is_type_supported(enum ggml_type type) { + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL1) { + return true; + } else { + return false; + } +} + +{% for kernel_shape in kernel_shapes -%} +{%- set pre = kernel_shape[0] ~ "_" ~ kernel_shape[1] -%} +{%- set BM = BM_list[loop.index0] -%} +{%- set BK = BK_list[loop.index0] -%} +{%- set bm = bm_list[loop.index0] -%} +{%- set by = 256 // bm -%} +{%- set k_list_indexed = k_list[loop.index0] -%} +{%- set length = 4 -%} + +static constexpr auto BM{{ pre }} = {{ BM }}; +static constexpr auto BBK{{ pre }} = {{ BK }}; +inline void tbl_impl_{{ pre }}(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK{{ pre }} / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[{{ bm // 8 }}]; + #pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + + #pragma unroll + for (int i = 0; i < BM{{ pre }}; i += {{ bm }}) { + #pragma unroll + for (int i=0; i<{{ bm // 8 }}; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + + #pragma unroll + for (int k = 0; k < KK / {{ 256 // bm // 2 }}; k++) { +{% for index in range(length) %} + uint8x16_t vec_a_{{ index }} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {{ index }} * 16); + uint8x16_t vec_a{{ index }}_top = vshrq_n_u8(vec_a_{{ index }}, 4); + uint8x16_t vec_a{{ index }}_bot = vandq_u8(vec_a_{{ index }}, vec_mask); + int8x16_t vec_v_{{ index }}_left_tmp0 = vqtbl1q_s8(vec_lut[{{ 2 * by // 2 }} * k + {{ (4 * index) % (2 * by // 2) }}], vec_a{{ index }}_top); + int8x16_t vec_v_{{ index }}_left_tmp1 = vqtbl1q_s8(vec_lut[{{ 2 * by // 2 }} * k + {{ (4 * index + 1) % (2 * by // 2) }}], vec_a{{ index }}_top); + int8x16_t vec_v_{{ index }}_right_tmp0 = vqtbl1q_s8(vec_lut[{{ 2 * by // 2 }} * k + {{ (4 * index + 2) % (2 * by // 2) }}], vec_a{{ index }}_bot); + int8x16_t vec_v_{{ index }}_right_tmp1 = vqtbl1q_s8(vec_lut[{{ 2 * by // 2 }} * k + {{ (4 * index + 3) % (2 * by // 2) }}], vec_a{{ index }}_bot); + int8x16x2_t vec_v_left_{{ index }} = vzipq_s8(vec_v_{{ index }}_left_tmp1, vec_v_{{ index }}_left_tmp0); + int8x16x2_t vec_v_right_{{ index }} = vzipq_s8(vec_v_{{ index }}_right_tmp1, vec_v_{{ index }}_right_tmp0); + vec_c[{{ (index * 2) // (by // 2) * 2 + 0 }}] += vec_v_left_{{ index }}.val[0]; + vec_c[{{ (index * 2) // (by // 2) * 2 + 0 }}] += vec_v_right_{{ index }}.val[0]; + vec_c[{{ (index * 2) // (by // 2) * 2 + 1 }}] += vec_v_left_{{ index }}.val[1]; + vec_c[{{ (index * 2) // (by // 2) * 2 + 1 }}] += vec_v_right_{{ index }}.val[1]; +{% endfor %} + } +{% for index in range(bm // 8) %} + int32x4_t 
+        int32x4_t vec_v_bot_low_low_{{ index }} = vmovl_s16(vget_low_s16(vec_c[{{ index }}]));
+        int32x4_t vec_v_bot_low_high_{{ index }} = vmovl_high_s16(vec_c[{{ index }}]);
+        vst1q_s32(c + i + {{ index * 8 }}, vld1q_s32(c + i + {{ index * 8 }}) + vec_v_bot_low_low_{{ index }});
+        vst1q_s32(c + i + {{ index * 8 + 4 }}, vld1q_s32(c + i + {{ index * 8 + 4 }}) + vec_v_bot_low_high_{{ index }});
+{% endfor %}
+    }
+#endif
+}
+
+int32_t qgemm_lut_{{ pre }}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas({{ min(32, BK) }}) uint32_t CBits[BM{{ pre }}];
+    memset(&(CBits[0]), 0, BM{{ pre }} * sizeof(int32_t));
+    #pragma unroll
+    for (int32_t k_outer = 0; k_outer < {{ k_list_indexed }} / BBK{{ pre }}; ++k_outer) {
+        tbl_impl_{{ pre }}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{{ pre }} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{{ pre }} / 2 / 2 * BM{{ pre }})])));
+    }
+    #pragma unroll
+    for (int i = 0; i < BM{{ pre }}; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+    return 0;
+};
+{% endfor %}
+
+template<int K>
+void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {
+    partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
+    per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
+
+    lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
+}
+
+void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
+{% for kernel_shape in kernel_shapes %}
+    {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) {
+        preprocessor_k<{{ kernel_shapes[loop.index0][1] }}>(B, LUT_Scales, QLUT);
+    }
+{% endfor %}
+}
+void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+{% for kernel_shape in kernel_shapes %}
+    {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) {
+        qgemm_lut_{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }}(A, LUT, Scales, LUT_Scales, C);
+    }
+{% endfor %}
+}
+
+void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
+    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
+        return;
+    }
+
+    int k = tensor->ne[0];
+    int m = tensor->ne[1];
+    const int lut_scales_size = 1;
+    const int scales_size = 1;
+    int bk = 0;
+    int bm = 0;
+{% for kernel_shape in kernel_shapes %}
+    {% if loop.index0 > 0 %}else {% endif %}if (m == {{ kernel_shapes[loop.index0][0] }} && k == {{ kernel_shapes[loop.index0][1] }}) {
+        bm = BM{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }};
+        bk = BBK{{ kernel_shapes[loop.index0][0] }}_{{ kernel_shapes[loop.index0][1] }};
+    }
+{% endfor %}
+    const int n_tile_num = m / bm;
+    const int BK = bk;
+    uint8_t * qweights;
+    bitnet_float_type * scales;
+
+    scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
+    qweights = (uint8_t *) tensor->data;
+    float * i2_scales = (float * )(qweights + k * m / 4);
+    scales[0] = (bitnet_float_type) i2_scales[0];
+
+    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
+    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
+    /* .lut_scales_size = */ lut_scales_size,
+    /* .BK = */ BK,
+    /* .n_tile_num = */ n_tile_num,
+    /* .qweights = */ qweights,
+    /* .scales = */ scales
+    };
+}
+#endif