From 351e345c6f1bdcf62d4c6c95dc67d6d8ccdd2bf1 Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Thu, 10 Oct 2024 16:35:37 +0800 Subject: [PATCH 01/12] Integrate T-MAC kernels --- convert_hf_to_gguf.py | 129 +++++++++++- ggml/CMakeLists.txt | 3 + ggml/include/ggml-tmac.h | 42 ++++ ggml/include/ggml.h | 4 + ggml/src/CMakeLists.txt | 57 +++++- ggml/src/ggml-cpu.c | 166 +++++++++++++++ ggml/src/ggml-quants.c | 6 + ggml/src/ggml-tmac.cpp | 398 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml.c | 44 ++++ gguf-py/gguf/constants.py | 14 ++ gguf-py/gguf/gguf_writer.py | 4 +- gguf-py/gguf/quants.py | 9 + include/llama.h | 1 + src/llama.cpp | 33 ++- 14 files changed, 899 insertions(+), 11 deletions(-) create mode 100644 ggml/include/ggml-tmac.h create mode 100644 ggml/src/ggml-tmac.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 76ee6cef52ac0..6f31fc87815b3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -65,6 +65,9 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path + is_lora: bool + enable_t_mac: bool + kcfg_file: Path | None # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -73,7 +76,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None): + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, is_lora: bool = False, + enable_t_mac: bool = False, kcfg_file: Path | None = None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -95,7 +99,9 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - + self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py + self.enable_t_mac = enable_t_mac + self.kcfg_file = kcfg_file # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. @@ -265,6 +271,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + _gptq_quant_dict: dict[str, Tensor] | None = None + _t_mac_bits: int = 0 + _t_mac_raw_shape: tuple[int, ...] | None = None + + # Repack and merge qweight, scales, and qzeros into a single tensor + # Currently, this logic is nearly impossible to be implemented in quants.py + def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if not self.enable_t_mac: + return self.modify_tensors(data_torch, name, bid) + + # bits = 0 means not quantized + self._t_mac_bits = 0 + self._t_mac_raw_shape = None + + from t_mac.model_utils import get_quantization_config, preprocess_for_t_mac + quantization_config = get_quantization_config(self.dir_model) + + if quantization_config["quant_method"] == "gptq": # AutoGPTQ/GPTQModel + if name.endswith(".g_idx"): + return [] + + if name.endswith(".qweight") or name.endswith(".scales") or name.endswith(".qzeros"): + if self._gptq_quant_dict is None: + self._gptq_quant_dict = {} + suffix = "." 
+ name.split(".")[-1] + base_name = name.replace(suffix, "") + self._gptq_quant_dict.setdefault(base_name, {})[suffix] = data_torch + if len(self._gptq_quant_dict[base_name]) < 3: + return [] + + qweight = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qweight"]).numpy() + scales = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".scales"]).numpy() + qzeros = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qzeros"]).numpy() + name = base_name + ".weight" + from t_mac.model_utils import unpack_gptqv2 + w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros, "gptqmodel" in quantization_config["quantizer"]) + self._t_mac_bits = bits + self._t_mac_raw_shape = w.shape + if bits != quantization_config["bits"] or group_size != quantization_config["group_size"]: + logger.warning("Error while parsing weights for quantization_config: {}".format(quantization_config)) + + # For permutation in, e.g., LlamaModel + w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy() + scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy() + zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy() + + if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N: + if quantization_config["sym"]: + if not np.allclose(zeros, np.zeros_like(zeros)): + logger.warning("Although the quantized model claimed to be symmetric, the weights are asymmetric") + else: + zeros = None + data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scales, zeros, bits=bits)) + else: + old_shape = w.shape + w = w.astype("float32").reshape(-1, group_size) + scales = scales.astype("float32").reshape(-1, 1) + zeros = zeros.astype("float32").reshape(-1, 1) + data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales + data_torch = torch.from_numpy(data.reshape(old_shape)) + if self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_torch = data_torch.to(torch.float16) + else: + return self.modify_tensors(data_torch, name, bid) + + return [(self.map_tensor_name(name), data_torch)] + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -285,7 +358,7 @@ def prepare_tensors(self): old_dtype = data_torch.dtype # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): + if data_torch.dtype not in (torch.float16, torch.float32) and not self.enable_t_mac: data_torch = data_torch.to(torch.float32) # use the first number-like part of the tensor name as the block id @@ -295,7 +368,7 @@ def prepare_tensors(self): bid = int(part) break - for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + for new_name, data_torch in (self._modify_tensors(data_torch, name, bid)): data = data_torch.squeeze().numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore @@ -344,6 +417,19 @@ def prepare_tensors(self): # TODO: use Q4_K and Q6_K data_qtype = gguf.GGMLQuantizationType.F16 + # If self._t_mac_bits > 0, the tensor is quantized by GPTQ + if self.enable_t_mac and self._t_mac_bits > 0 and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N: + if self._t_mac_bits == 1: + data_qtype = gguf.GGMLQuantizationType.I1 + elif self._t_mac_bits == 2: + data_qtype = gguf.GGMLQuantizationType.I2 + elif self._t_mac_bits == 3: + data_qtype = gguf.GGMLQuantizationType.I3 + elif self._t_mac_bits == 4: + data_qtype = gguf.GGMLQuantizationType.I4 + else: + raise 
ValueError(f"Unsupported number of bits: {self._t_mac_bits}") + # No override (data_qtype is False), or wants to be quantized (data_qtype is True) if isinstance(data_qtype, bool): if self.ftype == gguf.LlamaFileType.ALL_F32: @@ -358,6 +444,12 @@ def prepare_tensors(self): data_qtype = gguf.GGMLQuantizationType.TQ1_0 elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0: data_qtype = gguf.GGMLQuantizationType.TQ2_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_INT_N: + # If the tensor is successfully quantized, data_qtype should be I1/2/3/4 + # If data_qtype is still bool, then the tensor should not be quantized + # In practice, this tensor is `output.weight` for GPTQ models + # TODO: Consider quantizing it? + data_qtype = gguf.GGMLQuantizationType.F16 else: raise ValueError(f"Unknown file type: {self.ftype.name}") @@ -369,6 +461,7 @@ def prepare_tensors(self): data = gguf.quants.quantize(data, data_qtype) shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + shape = self._t_mac_raw_shape or shape # reverse shape to make it similar to the internal ggml dimension order shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" @@ -376,7 +469,8 @@ def prepare_tensors(self): # n_dims is implicit in the shape logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") - self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) + raw_shape = gguf.quant_shape_to_byte_shape(self._t_mac_raw_shape, data_qtype) if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N and self._t_mac_raw_shape else None + self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype, raw_shape=raw_shape) def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.MODEL) @@ -1700,6 +1794,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter ]): # transform weight into 1/0/-1 (in fp32) data_torch = self.weight_quant(data_torch) + if self.enable_t_mac: + # transform weight into T-MAC I2 format + from t_mac.model_utils import preprocess_for_t_mac + data = LazyTorchTensor.to_eager(data_torch).numpy() + scale = np.max(np.abs(data)) + w = np.round(data / scale + 2).astype(np.uint8) + data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scale.reshape(1), bits=2)) + self._t_mac_bits = 2 + self._t_mac_raw_shape = w.shape yield (new_name, data_torch) @@ -4297,8 +4400,8 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "int_n", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and int_n for int1/2/3/4, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( "--bigendian", action="store_true", @@ -4344,6 +4447,14 @@ def parse_args() -> argparse.Namespace: "--metadata", type=Path, help="Specify the path for an authorship metadata override file" ) + parser.add_argument( + "--enable-t-mac", action="store_true", + help="Enable T-MAC quantization format (disabled by default). Support GPTQ, GPTQv2, BitNet and BitDistiller." + ) + parser.add_argument( + "--kcfg", type=Path, + help="Specify the path for the T-MAC configuration file" + ) return parser.parse_args() @@ -4387,6 +4498,7 @@ def main() -> None: "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, + "int_n": gguf.LlamaFileType.MOSTLY_INT_N, "auto": gguf.LlamaFileType.GUESSED, } @@ -4420,7 +4532,8 @@ def main() -> None: metadata_override=args.metadata, model_name=args.model_name, split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split) + small_first_shard=args.no_tensor_first_split, + enable_t_mac=args.enable_t_mac, kcfg_file=args.kcfg) if args.vocab_only: logger.info("Exporting model vocab...") diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index cfa6e3f70e4a3..8c586640bbf9c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -166,6 +166,9 @@ option(GGML_SYCL "ggml: use SYCL" option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") +option(GGML_TMAC "ggml: use TMAC" OFF) +option(GGML_TMAC_SYSLIB "ggml: use TMAC system library" OFF) +option(GGML_TMAC_TVM_THREADPOOL "ggml: use TVM threadpool for TMAC" OFF) # extra artifacts option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) diff --git a/ggml/include/ggml-tmac.h b/ggml/include/ggml-tmac.h new file mode 100644 index 0000000000000..f79b674455dc6 --- /dev/null +++ b/ggml/include/ggml-tmac.h @@ -0,0 +1,42 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __ARM_NEON +#include +typedef float16_t tmac_float_type; +#else +typedef float tmac_float_type; +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct tmac_tensor_extra { + int lut_scales_size; + int scales_size; + int n_tile_num; + uint8_t * qweights; + tmac_float_type * scales; +}; + +GGML_API void ggml_tmac_init(void); +GGML_API void ggml_tmac_free(void); +// src0->type == Q4_0/IQ2_XXS/IQ3_XXS +// T-MAC currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros) +// If use i-quantization gguf models, the results will be wrong +// TODO: add customized block types Q2_0/Q3_0 +GGML_API bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, 
const struct ggml_tensor * dst); +GGML_API size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); +GGML_API void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits); +GGML_API void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits); +GGML_API void ggml_tmac_transform_tensor(struct ggml_tensor * tensor); +GGML_API int ggml_tmac_get_type_bits(enum ggml_type type); +GGML_API void ggml_tmac_set_n_threads(int n_threads); +GGML_API size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8a0bcbff8c61a..385330c7ae679 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -389,6 +389,10 @@ extern "C" { GGML_TYPE_Q4_0_8_8 = 33, GGML_TYPE_TQ1_0 = 34, GGML_TYPE_TQ2_0 = 35, + GGML_TYPE_I1 = 36, + GGML_TYPE_I2 = 37, + GGML_TYPE_I3 = 38, + GGML_TYPE_I4 = 39, GGML_TYPE_COUNT, }; diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 34b81bd7fdda1..778a105c855fd 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -870,6 +870,38 @@ if (GGML_KOMPUTE) endif() endif() +if (GGML_TMAC) + find_package(TMAC) + + if (TMAC_FOUND) + message(STATUS "TMAC found") + + list(APPEND GGML_CDEF_PUBLIC GGML_USE_TMAC) + + set(GGML_HEADERS_TMAC ../include/ggml-tmac.h) + set(GGML_SOURCES_TMAC ggml-tmac.cpp) + + link_directories(${TMAC_LIB_DIR}) + file(COPY ${TMAC_LIB_DIR}/kcfg.ini DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + # TODO: link t_mac_object when GGML_TMAC_SYSLIB + + if (GGML_TMAC_TVM_THREADPOOL) + add_compile_definitions(TMAC_USE_TVM_THREADPOOL) + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} t_mac) + else() + if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR + (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) + message(FATAL_ERROR "Clang is required for T-MAC compilation") + endif() + + set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac_no_tvm) + set(GGML_SOURCES_TMAC ${GGML_SOURCES_TMAC} ${TMAC_KERNELS_SOURCE}) + endif() + else() + message(WARNING "TMAC not found") + endif() +endif() + if (GGML_CPU_HBM) find_library(memkind memkind REQUIRED) @@ -1170,6 +1202,26 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR # Raspberry Pi 3, 4, Zero 2 (32-bit) list(APPEND ARCH_FLAGS -mno-unaligned-access) endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" AND GGML_TMAC AND TMAC_FOUND) + # We need fullfp16 for T-MAC + # TODO: we need to simplify this logic through check_cxx_source_compiles or Presets? 
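+        # If the int8-matmul probe below compiles, assume an armv8.7-a class core and
+        # enable fp16 plus fast-math vectorization; otherwise fall back to plain
+        # armv8.2-a+fp16 (e.g. Jetson AGX Orin, Raspberry Pi 5).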
+ check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8) + if (GGML_COMPILER_SUPPORT_MATMUL_INT8) + # Device with armv8.7a+ cpu, e.g., WSL on Surface Laptop 7 + # based on arm64-windows-llvm.cmake + list(APPEND ARCH_FLAGS -march=armv8.7-a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only) + add_compile_definitions(__ARM_FEATURE_MATMUL_INT8) + else () + # Jetson AGX Orin, Raspberry Pi 5 + list(APPEND ARCH_FLAGS -march=armv8.2a+fp16) + endif () + endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ARM64" AND GGML_TMAC AND TMAC_FOUND) + # ARM Windows with LLVM clang GNU interface + # We need fullfp16 for T-MAC + # TODO: check_cxx_source_compiles + list(APPEND ARCH_FLAGS -march=armv8.2a+fp16) + endif() if (GGML_SVE) list(APPEND ARCH_FLAGS -march=armv8.6-a+sve) endif() @@ -1184,7 +1236,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW # TODO: improve, should not reference files from the parent folder include(../cmake/FindSIMD.cmake) endif () - if (GGML_AVX512) + # Can't use GGML_AVX512 with Clang for MSVC + # with error: conflicting types for '_m_prefetchw + if (GGML_AVX512 AND (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") AND (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) list(APPEND ARCH_FLAGS /arch:AVX512) # MSVC has no compile-time flags enabling specific # AVX512 extensions, neither it defines the @@ -1388,6 +1442,7 @@ add_library(ggml ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} ${GGML_SOURCES_AMX} ${GGML_HEADERS_AMX} ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN} + ${GGML_SOURCES_TMAC} ${GGML_HEADERS_TMAC} ggml-aarch64.c ggml-aarch64.h ) diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c index 4b8ffb629afbb..adc58ad0a783d 100644 --- a/ggml/src/ggml-cpu.c +++ b/ggml/src/ggml-cpu.c @@ -84,6 +84,10 @@ #include #endif +#if defined(GGML_USE_TMAC) +#include "ggml-tmac.h" +#endif + // floating point type used to accumulate sums typedef double ggml_float; @@ -7478,6 +7482,155 @@ static void ggml_compute_forward_mul_mat( UseGgmlGemm1:; #endif +// TODO: Refactor t-mac as ggml-backend, +// as ggml-blas.cpp has been moved to backend +#if defined(GGML_USE_TMAC) + if (ggml_tmac_can_mul_mat(src0, src1, dst)) { + const int bits = ggml_tmac_get_type_bits(type); + // src0: weight, ne00 = k, ne01 = n + // src1: activation, ne10 = k, ne11 = m + char * wdata = params->wdata; + + struct tmac_tensor_extra * wt = src0->extra; + char * cur_wdata = wdata; + tmac_float_type * tmac_f_ptr = wdata; + if (sizeof(tmac_float_type) == 2) { + cur_wdata = wdata + MAX(ne10, ne01) * ne11 * sizeof(tmac_float_type); + }; + int8_t * qlut = cur_wdata; + tmac_float_type * lut_scales = (tmac_float_type *) (qlut + ne10 * ne11 * 4); + tmac_float_type * lut_biases = (tmac_float_type *) (lut_scales + wt->lut_scales_size * ne11); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + tmac_float_type * act_input; + if (sizeof(tmac_float_type) == 2) { + act_input = tmac_f_ptr; + } else { + act_input = src1->data; + } + for (int ine11 = ith; ine11 < ne11; ine11 += nth) { + if (sizeof(tmac_float_type) == 2) { + ggml_fp32_to_fp16_row((const float *) src1->data + ne10 * ine11, act_input + ne10 * ine11, ne10); + } + ggml_tmac_mul_mat_task_init(act_input + ne10 * ine11, + qlut + ne10 * ine11 * 4, + lut_scales + wt->lut_scales_size * ine11, + lut_biases + wt->lut_scales_size * ine11, + ne01, ne00, 1, bits); + } + + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. 
This save a bit of coordination right at the start. + atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); + } + + ggml_barrier(params->threadpool); + + tmac_float_type * act_output; + if (sizeof(tmac_float_type) == 2) { + act_output = tmac_f_ptr; + } else { + act_output = dst->data; + } +// TODO: remove TVM threadpool if ensuring unused +#if defined(TMAC_USE_TVM_THREADPOOL) + if (ith != 0) { + return; + } + // TODO: schedule ne11(m) in T-MAC + for (int ine11 = 0; ine11 < ne11; ine11++) { + const int qlut_offset = ne10 * ine11 * 4; + const int lut_scales_offset = wt->lut_scales_size * ine11; + const int dst_offset = ne0 * ine11; + + ggml_tmac_mul_mat_task_compute(wt->qweights, + wt->scales, + qlut + qlut_offset, + lut_scales + lut_scales_offset, + lut_biases + lut_scales_offset, + act_output + dst_offset, + ne01, ne00, 1, bits); + } + if (sizeof(tmac_float_type) == 2) { + ggml_fp16_to_fp32_row(tmac_f_ptr, dst->data, ne00 * ne01); + } +#else // #if defined(TMAC_USE_TVM_THREADPOOL) + const int n_tile_num = wt->n_tile_num; + // Currently, T-MAC requires ne0 devisible by n_tile_num + GGML_ASSERT(ne0 % n_tile_num == 0); + + const int64_t w_size = ne00 * ne01 * bits / 8; + const int64_t w_chunk_size = w_size / n_tile_num; + + const int64_t nr0 = ne0; + const int64_t nr1 = ne1 * ne2 * ne3; + + // Adopt the same style with current llama.cpp impl + // But different chunk size for 0/1 dim. + // No scrap. + const int chunk_size0 = ne0 / n_tile_num; + const int chunk_size1 = 8; // TODO: tune in T-MAC + + // nchunk0 == n_tile_num + int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0; + int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1; + + int64_t dr0 = chunk_size0; + int64_t dr1 = chunk_size1; + + // Rechunk + if ((nchunk1 == 1) && (nchunk0 > nth * 4)) { + // dr0 should be divisible by chunk_size0 + dr0 = (ne0 / (nth * 4) / chunk_size0) * chunk_size0; + nchunk0 = (nr0 + dr0 - 1) / dr0; + } + + int current_chunk = ith; + + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; + + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); + + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + // inline ggml_compute_forward_mul_mat_one_chunk here for simplicity + for (int64_t ichunk0 = ir0_start / chunk_size0; ichunk0 < ir0_end / chunk_size0; ichunk0++) { + const int64_t w_offset = ichunk0 * w_chunk_size; + const int64_t scales_offset = ichunk0 * wt->scales_size / n_tile_num; + + for (int64_t ine11 = ir1_start; ine11 < ir1_end; ine11++) { + const int64_t qlut_offset = ne10 * ine11 * 4; + const int64_t lut_scales_offset = wt->lut_scales_size * ine11; + const int64_t dst_offset = ne0 * ine11 + ichunk0 * chunk_size0; + + ggml_tmac_mul_mat_task_compute(wt->qweights + w_offset, + wt->scales + scales_offset, + qlut + qlut_offset, + lut_scales + lut_scales_offset, + lut_biases + lut_scales_offset, + act_output + dst_offset, + chunk_size0, ne00, 1, bits); + if (sizeof(tmac_float_type) == 2) { + ggml_fp16_to_fp32_row(act_output + dst_offset, (float *) dst->data + dst_offset, chunk_size0); + } + } + } + + if (nth >= nchunk0 * nchunk1) { + break; + } + + current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); + } +#endif // #if defined(TMAC_USE_TVM_THREADPOOL) + return; + } // if (ggml_tmac_can_mul_mat(src0, src1, dst)) +#endif // #if defined(GGML_USE_TMAC) + if (src1->type 
!= vec_dot_type) { char * wdata = params->wdata; @@ -9123,6 +9276,10 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_I1: + case GGML_TYPE_I2: + case GGML_TYPE_I3: + case GGML_TYPE_I4: case GGML_TYPE_COUNT: { GGML_ABORT("fatal error"); @@ -13172,6 +13329,11 @@ struct ggml_cplan ggml_graph_plan( { const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type; +#if defined(GGML_USE_TMAC) + if (ggml_tmac_can_mul_mat(node->src[0], node->src[1], node)) { + cur = ggml_tmac_mul_mat_get_wsize(node->src[0], node->src[1], node); + } else +#endif if (node->src[1]->type != vec_dot_type) { cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); } @@ -13708,6 +13870,10 @@ void ggml_cpu_init(void) { ggml_init_arm_arch_features(); #endif +#if defined(GGML_USE_TMAC) + ggml_tmac_init(); +#endif + is_first_call = false; } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 7aa6dce8907f5..8d828b8c0180b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15741,6 +15741,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_I64: // nothing to validate break; + case GGML_TYPE_I1: + case GGML_TYPE_I2: + case GGML_TYPE_I3: + case GGML_TYPE_I4: + // nothing to validate + break; default: { fprintf(stderr, "%s: invalid type %d\n", __func__, type); diff --git a/ggml/src/ggml-tmac.cpp b/ggml/src/ggml-tmac.cpp new file mode 100644 index 0000000000000..01552e058a7cb --- /dev/null +++ b/ggml/src/ggml-tmac.cpp @@ -0,0 +1,398 @@ +#include +#include + +#include "ggml-tmac.h" +#include "ggml-quants.h" + +#include "t-mac/tmac_gemm_wrapper.h" + +#define GGML_TMAC_MAX_NODES 8192 + +static bool initialized = false; + +static TMAC::TMACGeMMWrapper * wrapper = nullptr; + +static tmac_tensor_extra * tmac_tensor_extras = nullptr; + +static size_t tmac_tensor_extras_index = 0; + +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, TMAC::kAllocAlignment); +#else + void * ptr = nullptr; + posix_memalign(&ptr, TMAC::kAllocAlignment, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +void ggml_tmac_init(void) { + LOG(INFO) << "ggml_tmac_init"; + + if (initialized) { + return; + } + initialized = true; + + if (wrapper == nullptr) { + wrapper = new TMAC::TMACGeMMWrapper(); + } + if (tmac_tensor_extras == nullptr) { + tmac_tensor_extras = new tmac_tensor_extra[GGML_TMAC_MAX_NODES]; + } + tmac_tensor_extras_index = 0; +} + +void ggml_tmac_free(void) { + LOG(INFO) << "ggml_tmac_free"; + + if (!initialized) { + return; + } + initialized = false; + + delete wrapper; + wrapper = nullptr; + for (size_t i = 0; i < tmac_tensor_extras_index; i++) { + // aligned_free(tmac_tensor_extras[i].qweights); + // aligned_free(tmac_tensor_extras[i].scales); + } + delete[] tmac_tensor_extras; + tmac_tensor_extras = nullptr; +} + +static bool is_type_supported(enum ggml_type type) { + if (//type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_I1 || + type == GGML_TYPE_I2 || + type == GGML_TYPE_I3 || + type == GGML_TYPE_I4) { + return true; + } else { + return false; + } +} + +static bool do_permutate(enum ggml_type type) { + if (type == GGML_TYPE_I1 || + type == GGML_TYPE_I2 || + type == GGML_TYPE_I3 || + type == GGML_TYPE_I4) { + // Add additional args to decide if permuted I2 or naive I2 + return false; + } else { + return true; + } +} + +struct 
BlockQ40TypeAccessor { + using block_t = block_q4_0; + + static constexpr int BITS = 4; + static constexpr int SIMD_LEN = 16; + static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS; + static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs); + int internal_idx = idx % group_size; + const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN; + int simd_idx = internal_idx % simd_n_elem; + return simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS); + } + + static tmac_float_type get_scale(const void * data, int idx) { + ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d; + if (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&d); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(((const block_t *) data)[idx / group_size].d); + } + } +}; + +struct BlockI2TypeAccessor { + static constexpr int BITS = 2; + static constexpr int n_elem = 8 / BITS; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) data; + int elem_idx = idx % n_elem; + return qs[idx / n_elem] >> (elem_idx * BITS); + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + const float * ss = (const float *) data; + float s = ss[idx / group_size]; + return (tmac_float_type) s; + } +}; + +bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { + if ((is_type_supported(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + src0->backend == GGML_BACKEND_TYPE_CPU) { + return true; + } + return false; +} + +size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { + const size_t ne01 = src0->ne[1]; + const size_t ne10 = src1->ne[0]; + const size_t ne11 = src1->ne[1]; + const int bits = ggml_tmac_get_type_bits(src0->type); + + TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(ne01, ne10, 1, bits); + + size_t wsize = ne10 * ne11 * 4 * sizeof(int8_t) + kcfg.lut_scales_size * ne11 * 2 * sizeof(tmac_float_type); + if (sizeof(tmac_float_type) == 2) { + // Need fp32 to fp16 conversion + wsize += std::max(ne10, ne01) * ne11 * sizeof(tmac_float_type); + } + wsize = ((wsize - 1) / TMAC::kAllocAlignment + 1) * TMAC::kAllocAlignment; + return wsize; +} + +// m = batch_size +// n = output_dim +void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits) { + // t-mac llama.cpp n and m swapped + wrapper->llama_cpp_init(src1, qlut, lut_scales, lut_biases, n, k, m, bits); +} + +void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits) { + wrapper->llama_cpp_compute(src0, scales, qlut, lut_scales, lut_biases, dst, n, k, m, bits); +} + +size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor) { + const int bits = ggml_tmac_get_type_bits(tensor->type); + + int k = tensor->ne[0]; + int m = tensor->ne[1]; // `n` in llama.cpp + + TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(m, k, 1, bits); + // Currently, I2 always uses float to store scales or zero points + size_t nbytes = k * m / 8 * bits + kcfg.scales_size * sizeof(float); + return nbytes; +} + +void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) { + 
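+    // Repack a supported quantized weight tensor into the T-MAC layout described by
+    // the kernel config (kcfg) looked up for this (m, k, bits). Types that require
+    // permutation (see do_permutate) have their bit-planes repacked into a freshly
+    // allocated qweights buffer; the I1..I4 types arrive already packed from
+    // convert_hf_to_gguf.py, so only their float scales are converted to
+    // tmac_float_type. The result is published through tensor->extra; tensors that
+    // already carry an extra are left untouched.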
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + const int bits = ggml_tmac_get_type_bits(tensor->type); + const int g = 4; + const int ngroups_per_elem = 2; + + int k = tensor->ne[0]; + int m = tensor->ne[1]; // `n` in llama.cpp + + TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(m, k, 1, bits); + const int bm = kcfg.bm; + const int simd_n_in = kcfg.simd_n_in; + const int simd_n_out = kcfg.simd_n_out; + const int kfactor = kcfg.kfactor; + const int group_size = kcfg.group_size; // could be different from block size in llama.cpp + const int lut_scales_size = kcfg.lut_scales_size; + const int scales_size = kcfg.scales_size; + const int n_tile_num = kcfg.n_tile_num; + DLOG(INFO) << "Transforming tensor: " << tensor->name << " (m: " << m << ", k: " << k << ", bits: " << bits << ")"; + DLOG(INFO) << "kcfg (bm=" << bm << ", simd_n_in=" << simd_n_in << ", simd_n_out=" << simd_n_out << ", kfactor=" << kfactor + << ", group_size=" << group_size << ", lut_scales_size=" << lut_scales_size << ", scales_size=" << scales_size << ", n_tile_num=" << n_tile_num << ")"; + if (bm == 0) { + // Instead of fatal error, try to avoid using t-mac? + LOG(FATAL) << "Failed to find kcfg. Abort transforming"; + return; + } + const int mgroup = ngroups_per_elem * simd_n_in; + m = m * bits; + + uint8_t * qweights; + tmac_float_type * scales; + + scales = (tmac_float_type *) aligned_malloc(scales_size * sizeof(tmac_float_type)); + if (do_permutate(tensor->type)) { + qweights = (uint8_t *) aligned_malloc(k * m / 8); + } else { + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 8); + for (int i = 0; i < scales_size; i++) { + scales[i] = (tmac_float_type) i2_scales[i]; + } + } + + tensor->extra = tmac_tensor_extras + tmac_tensor_extras_index; + tmac_tensor_extras[tmac_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .scales_size = */ scales_size, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; + + if (do_permutate(tensor->type)) { +// for fast testing +// #define TMAC_EMPTY_WEIGHTS +#ifndef TMAC_EMPTY_WEIGHTS + // TODO: optimize to accelerate weights loading + uint8_t * buf1 = new uint8_t[m * k]; + uint8_t * buf2 = new uint8_t[m * k / g]; + + // # (M // bits, K, bits) + // w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1) + for (int im = 0; im < m / bits; im++) { + for (int ik = 0; ik < k; ik++) { + for (int ib = 0; ib < bits; ib++) { + uint8_t v; + if (tensor->type == GGML_TYPE_Q4_0) { + v = BlockQ40TypeAccessor::get_q(tensor->data, im * k + ik); + } else if (tensor->type == GGML_TYPE_I2) { + v = BlockI2TypeAccessor::get_q(tensor->data, im * k + ik); + } + buf1[im * k * bits + ik * bits + ib] = (v >> ib) & 1; + } + } + } + + // # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) -> (M // bits, bits, K // g) + // w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g) + // w = sum([(w[:, :, :, ig] << ig) for ig in range(g)]) + memset(buf2, 0, m * k / g); + for (int im = 0; im < m / bits; im++) { + for (int ik = 0; ik < k; ik++) { + for (int ib = 0; ib < bits; ib++) { + int new_im = im; + int new_ib = ib; + int new_ik = ik / g; + int new_ig = ik % g; + buf2[new_im * bits * k / g + new_ib * k / g + new_ik] += buf1[im * k * bits + ik * bits + ib] << new_ig; + } + } + } + + // # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 
+ // # for bits=3 + // # bit0: [0, 8), bit1: [8, 16), bit2: [16, 24), bit0: [24, 32) + // # (M // bits // simd_n_float16, bits, simd_n_float16, K // g) + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + // mgroup = ngroups_per_elem * simd_n_in + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + memset(qweights, 0, m * k / g / ngroups_per_elem); + for (int im = 0; im < m / bits; im++) { + for (int ib = 0; ib < bits; ib++) { + for (int ik = 0; ik < k / g; ik++) { + int new_im = im / simd_n_out; + int new_isno = im % simd_n_out; + int new_ib = ib; + int new_ik = ik; + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + int new_idx = new_im * bits * simd_n_out * k / g + new_ib * simd_n_out * k / g + new_isno * k / g + new_ik; + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + int nb2 = k / g; + int nb1 = simd_n_in * nb2; + int nb0 = ngroups_per_elem * nb1; + new_im = new_idx / nb0; + int new_ing = (new_idx % nb0) / nb1; + int new_isni = (new_idx % nb1) / nb2; + new_ik = (new_idx % nb2); + new_idx = new_im * ngroups_per_elem * simd_n_in * k / g + new_isni * ngroups_per_elem * k / g + new_ing * k / g + new_ik; + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + int nb4 = kfactor; + int nb3 = k / g / kfactor * nb4; + nb2 = ngroups_per_elem * nb3; + nb1 = simd_n_in * nb2; + nb0 = bm / mgroup * nb1; + new_im = new_idx / nb0; + int new_ibm = (new_idx % nb0) / nb1; + new_isni = (new_idx % nb1) / nb2; + new_ing = (new_idx % nb2) / nb3; + new_ik = (new_idx % nb3) / nb4; + int new_ikf = (new_idx % nb4); + new_idx = new_im * k / g / kfactor * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem + + new_ik * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem + + new_ibm * kfactor * simd_n_in * ngroups_per_elem + + new_ikf * simd_n_in * ngroups_per_elem + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + qweights[new_idx] += buf2[im * bits * k / g + ib * k / g + ik] << (new_ing * g); + } + } + } + + const float * i2_scales = (const float * ) ((const uint8_t *) tensor->data + k * m / 8); + if (scales_size < m / bits) { // BitNet-like scale (m_groups,) + for (int i = 0; i < scales_size; i++) { + scales[i] = (tmac_float_type) i2_scales[i]; + } + } else { // GPTQ-like scale (m / bits, k / group_size) + GGML_ASSERT(scales_size == m / bits * k / group_size); + // scales = scales.reshape(M // bm, bm // bits, K // group_size).transpose(0, 2, 1) + for (int im = 0; im < m / bits; im += 1) { + for (int ik = 0; ik < k; ik += group_size) { + tmac_float_type scale; + int idx = im * k + ik; + if (tensor->type == GGML_TYPE_Q4_0) { + scale = BlockQ40TypeAccessor::get_scale(tensor->data, idx); + } else if (tensor->type == GGML_TYPE_I2) { + scale = BlockI2TypeAccessor::get_scale(i2_scales, idx, group_size); + } + int new_idx; + idx = idx / group_size; + int new_im = idx / (bm / bits * k / group_size); + int new_ibm = (idx % (bm / bits * k / group_size)) / (k / group_size); + int new_ik = (idx % (k / group_size)); + 
new_idx = new_im * k / group_size * bm / bits + new_ik * bm / bits + new_ibm; + scales[new_idx] = scale; + } + } + } + + delete[] buf1; + delete[] buf2; +#else + memset(qweights, 0x88, k * m / 8); + for (int i = 0; i < scales_size; i++) { + scales[i] = 1.0f; + } +#endif + } // if (do_permutate(tensor->type)) +} + +int ggml_tmac_get_type_bits(enum ggml_type type) { + switch (type) { + case GGML_TYPE_I1: + return 1; + case GGML_TYPE_I2: + return 2; + case GGML_TYPE_I3: + return 3; + case GGML_TYPE_I4: + return 4; + case GGML_TYPE_Q4_0: + return 4; + default: + return 0; + } +} + +void ggml_tmac_set_n_threads(int n_threads) { + wrapper->set_num_threads(n_threads); +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7dc3340a1e749..823abbe175644 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -541,6 +541,42 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_I1] = { + .type_name = "i1", + .blck_size = 8, + .type_size = sizeof(int8_t), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_I2] = { + .type_name = "i2", + .blck_size = 4, + .type_size = sizeof(int8_t), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_I3] = { + .type_name = "i3", + .blck_size = 2, + .type_size = sizeof(int8_t), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_I4] = { + .type_name = "i4", + .blck_size = 2, + .type_size = sizeof(int8_t), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -1161,6 +1197,14 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } } +#if defined(GGML_USE_TMAC) + if(tensor->type == GGML_TYPE_I1 || + tensor->type == GGML_TYPE_I2 || + tensor->type == GGML_TYPE_I3 || + tensor->type == GGML_TYPE_I4){ + nbytes = ggml_tmac_get_nbytes(tensor); + } +#endif return nbytes; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7ab08b036e527..ed52e0c2100fb 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1404,6 +1404,10 @@ class GGMLQuantizationType(IntEnum): Q4_0_8_8 = 33 TQ1_0 = 34 TQ2_0 = 35 + I1 = 36 + I2 = 37 + I3 = 38 + I4 = 39 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -1450,6 +1454,7 @@ class LlamaFileType(IntEnum): MOSTLY_Q4_0_8_8 = 35 # except 1d tensors MOSTLY_TQ1_0 = 36 # except 1d tensors MOSTLY_TQ2_0 = 37 # except 1d tensors + MOSTLY_INT_N = 38 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1528,6 +1533,15 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ2_0: (256, 2 + 64), + # Currently, we use tricks here + # - The block size doesn't include scales or zero_points as group_size is changeable + # - So the size is slightly smaller than the real size + # - The n_bytes in gguf_reader.py is thus inaccurate + # - During inference, the accurate nbytes info will be known through ggml_tmac_get_nbytes + 
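+    # - Illustration: with I2 declared as 4 elements per 1-byte block, a 4096 x 4096
+    #   I2 tensor is counted as 4096 * 4096 / 4 = 4 MiB of packed weights here;
+    #   the per-group float scales are only added later by ggml_tmac_get_nbytes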
GGMLQuantizationType.I1: (8, 1), + GGMLQuantizationType.I2: (4, 1), + GGMLQuantizationType.I3: (8, 3), + GGMLQuantizationType.I4: (2, 1), } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 0d8d8a0b087e9..e2b4838bf8813 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -345,8 +345,10 @@ def add_tensor_info( dtype = GGMLQuantizationType.I32 elif tensor_dtype == np.int64: dtype = GGMLQuantizationType.I64 + elif tensor_dtype == np.uint8: + dtype = GGMLQuantizationType.I2 else: - raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") + raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now") else: dtype = raw_dtype if tensor_dtype == np.uint8: diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index 3c8ba82e19d3d..6f2c5e10d881f 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -60,6 +60,15 @@ def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: return data.astype(np.float16, copy=False) elif (q := _type_traits.get(qtype)) is not None: return q.quantize(data) + # Do nothing for I1/2/3/4, as they are already quantized + elif qtype == GGMLQuantizationType.I1: + return data + elif qtype == GGMLQuantizationType.I2: + return data + elif qtype == GGMLQuantizationType.I3: + return data + elif qtype == GGMLQuantizationType.I4: + return data else: raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented") diff --git a/include/llama.h b/include/llama.h index ccb48f73cef5c..825d282d35495 100644 --- a/include/llama.h +++ b/include/llama.h @@ -176,6 +176,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors + LLAMA_FTYPE_MOSTLY_INT_N = 38, LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 3e563d811b77c..712d1d795bd5d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9,6 +9,10 @@ #include "ggml-backend.h" #include "ggml-cpp.h" +#ifdef GGML_USE_TMAC +# include "ggml-tmac.h" +#endif + // TODO: replace with ggml API call #define QK_K 256 @@ -4434,6 +4438,10 @@ struct llama_model_loader { case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; + case GGML_TYPE_I1: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break; + case GGML_TYPE_I2: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break; + case GGML_TYPE_I3: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break; + case GGML_TYPE_I4: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -4775,7 +4783,9 @@ struct llama_model_loader { void done_getting_tensors() const { if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + // Zero bias in some HuggingFace models will cause n_tensors mismatch + // Consider removing zero bias in convert_hf_to_gguf.py? 
+ // throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); } } @@ -5032,6 +5042,11 @@ struct llama_model_loader { } size_done += n_size; + +#if defined(GGML_USE_TMAC) + // Do pre-transformation to reduce first-run latency + ggml_tmac_transform_tensor(cur); +#endif } // free temporary resources used for async uploads @@ -5171,6 +5186,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_INT_N: return "INT_N"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; @@ -17183,6 +17199,13 @@ static void llama_graph_compute( ggml_cgraph * gf, int n_threads, ggml_threadpool * threadpool) { +#ifdef GGML_USE_TMAC + #ifdef TMAC_USE_TVM_THREADPOOL + ggml_tmac_set_n_threads(n_threads); + n_threads = 1; + #endif +#endif + if (lctx.backend_cpu != nullptr) { ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool); ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); @@ -18442,6 +18465,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_INT_N:default_type = GGML_TYPE_I2; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: @@ -18747,6 +18771,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } + if (tensor->type == GGML_TYPE_I1 || + tensor->type == GGML_TYPE_I2 || + tensor->type == GGML_TYPE_I3 || + tensor->type == GGML_TYPE_I4) { + // no need quantize for iN + new_type = tensor->type; + } // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. From 94502e44a7f414a7efe03d35eff9ebf76ef6fdad Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Fri, 11 Oct 2024 14:40:58 +0800 Subject: [PATCH 02/12] Fix a Cmake variable fault. --- ggml/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 778a105c855fd..7850fe0b5294d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -887,7 +887,7 @@ if (GGML_TMAC) if (GGML_TMAC_TVM_THREADPOOL) add_compile_definitions(TMAC_USE_TVM_THREADPOOL) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} t_mac) + set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac) else() if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) From f673699460167e0d3c97f1395057218ad8e3e2ff Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Fri, 11 Oct 2024 14:41:31 +0800 Subject: [PATCH 03/12] Remove is_lora in convert_hf_to_gguf, which is removed in master. 
--- convert_hf_to_gguf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6f31fc87815b3..a3d35c2db3c04 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -65,7 +65,6 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path - is_lora: bool enable_t_mac: bool kcfg_file: Path | None @@ -76,7 +75,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None, is_lora: bool = False, + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, enable_t_mac: bool = False, kcfg_file: Path | None = None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -99,7 +98,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py self.enable_t_mac = enable_t_mac self.kcfg_file = kcfg_file # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type From dfac0c4b3e6e575284b273c68df995023250379c Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Fri, 11 Oct 2024 14:43:14 +0800 Subject: [PATCH 04/12] Remove uint8 branch in gguf_writer. --- gguf-py/gguf/gguf_writer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e2b4838bf8813..0d8d8a0b087e9 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -345,10 +345,8 @@ def add_tensor_info( dtype = GGMLQuantizationType.I32 elif tensor_dtype == np.int64: dtype = GGMLQuantizationType.I64 - elif tensor_dtype == np.uint8: - dtype = GGMLQuantizationType.I2 else: - raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now") + raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") else: dtype = raw_dtype if tensor_dtype == np.uint8: From f64c7680550ebf4dd0453013524cc1054b311d17 Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Sat, 12 Oct 2024 12:48:45 +0800 Subject: [PATCH 05/12] Restore n_tensor check. --- convert_hf_to_gguf.py | 6 ++++++ src/llama.cpp | 4 +--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a3d35c2db3c04..45a3cb707713a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -367,6 +367,12 @@ def prepare_tensors(self): break for new_name, data_torch in (self._modify_tensors(data_torch, name, bid)): + # Some GPTQ models have empty bias tensors which are not in the model architecture. + # These tensors will cause tensor number check to fail, so we have to skip them. 
+ if new_name.endswith(".bias") and np.all(LazyTorchTensor.to_eager(data_torch).numpy() == 0): + logger.info(f"Skipping empty bias tensor: {new_name}") + continue + data = data_torch.squeeze().numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore diff --git a/src/llama.cpp b/src/llama.cpp index 712d1d795bd5d..96699f2608630 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4783,9 +4783,7 @@ struct llama_model_loader { void done_getting_tensors() const { if (n_created != n_tensors) { - // Zero bias in some HuggingFace models will cause n_tensors mismatch - // Consider removing zero bias in convert_hf_to_gguf.py? - // throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); } } From 6bb4acae7cae07fb8fe5c6bcfe4b7c71664493c8 Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Sat, 12 Oct 2024 13:27:30 +0800 Subject: [PATCH 06/12] Remove unused code. --- src/llama.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 96699f2608630..7529f3e04c33a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18463,7 +18463,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; - case LLAMA_FTYPE_MOSTLY_INT_N:default_type = GGML_TYPE_I2; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: From b2662907005628d25dc1d14253daccbb851ba195 Mon Sep 17 00:00:00 2001 From: kalineid Date: Mon, 14 Oct 2024 17:31:48 +0800 Subject: [PATCH 07/12] [llama.cpp] update convert_hf_to_gguf.py --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 45a3cb707713a..f44c4ca6d1dc7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -331,10 +331,10 @@ def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Ite data_torch = torch.from_numpy(data.reshape(old_shape)) if self.ftype == gguf.LlamaFileType.MOSTLY_F16: data_torch = data_torch.to(torch.float16) - else: - return self.modify_tensors(data_torch, name, bid) - return [(self.map_tensor_name(name), data_torch)] + return [(self.map_tensor_name(name), data_torch)] + + return self.modify_tensors(data_torch, name, bid) def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused From e86c69df8b5a2157f6a8909a4395b88456eb0fcd Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Wed, 23 Oct 2024 12:28:25 +0800 Subject: [PATCH 08/12] [Feat] Support TQ1_0 and TQ2_0 with T-MAC. 
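For example, a ternary checkpoint exported with `convert_hf_to_gguf.py --outtype tq1_0`
or `--outtype tq2_0` can now be dispatched to the T-MAC mul_mat path: the two new
accessor structs unpack the packed ternary values and the per-block fp16 scale into
the form consumed by `ggml_tmac_transform_tensor`.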
- Adding support for new tensor types `GGML_TYPE_TQ1_0` and `GGML_TYPE_TQ2_0` - Handling the case when the kcfg is not found for certain tensors (`token_embd.weight` and `output.weight`), displaying a warning message instead of a fatal error --- ggml/src/ggml-tmac.cpp | 153 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 139 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-tmac.cpp b/ggml/src/ggml-tmac.cpp index 01552e058a7cb..c5d20640893f7 100644 --- a/ggml/src/ggml-tmac.cpp +++ b/ggml/src/ggml-tmac.cpp @@ -70,11 +70,13 @@ void ggml_tmac_free(void) { } static bool is_type_supported(enum ggml_type type) { - if (//type == GGML_TYPE_Q4_0 || + if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_I1 || type == GGML_TYPE_I2 || type == GGML_TYPE_I3 || - type == GGML_TYPE_I4) { + type == GGML_TYPE_I4 || + type == GGML_TYPE_TQ1_0 || + type == GGML_TYPE_TQ2_0) { return true; } else { return false; @@ -115,7 +117,7 @@ struct BlockQ40TypeAccessor { tmac_float_type * fp16dp = reinterpret_cast(&d); return *fp16dp; } else { - return ggml_fp16_to_fp32(((const block_t *) data)[idx / group_size].d); + return ggml_fp16_to_fp32(d); } } }; @@ -137,11 +139,109 @@ struct BlockI2TypeAccessor { } }; +struct BlockTQ10TypeAccessor { + using block_t = block_tq1_0; + + static constexpr int elements_qs = 5; // 5 elements per byte + static constexpr int elements_qh = 4; // 4 elements per byte + static constexpr int BITS = 2; + static constexpr int group_size_qs = sizeof(((block_t *)0)->qs) * elements_qs; + static constexpr int group_size_qh = sizeof(((block_t *)0)->qh) * elements_qh; + static constexpr int group_size = group_size_qs + group_size_qh; + static constexpr int SIMD_LEN_qs_1 = 32; + static constexpr int SIMD_LEN_qs_2 = 16; + static constexpr int SIMD_LEN_qh = 4; + static constexpr int simd_n_elem_qs_1 = SIMD_LEN_qs_1 * elements_qs; // 160 + static constexpr int simd_n_elem_qs_2 = SIMD_LEN_qs_2 * elements_qs; // 80 + static constexpr int simd_n_elem_qh = SIMD_LEN_qh * elements_qh; // 16 + + static constexpr uint8_t pow3[5] = {1, 3, 9, 27, 81}; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs); + uint8_t cur_qs; + uint8_t trit; + int internal_idx = idx % group_size; + + if (internal_idx < simd_n_elem_qs_1) { + const int internal_offset = 0; + const uint8_t * simd_qs = qs + internal_offset; + int simd_idx = internal_idx; + int simd_byte = simd_idx % SIMD_LEN_qs_1; + int simd_trit = simd_idx / SIMD_LEN_qs_1; + + cur_qs = simd_qs[simd_byte] * pow3[simd_trit]; + trit = ((uint16_t) cur_qs * 3) >> 8; + } + else if (internal_idx < simd_n_elem_qs_1 + simd_n_elem_qs_2) { + const int internal_offset = SIMD_LEN_qs_1; + const uint8_t * simd_qs = qs + internal_offset; + int simd_idx = internal_idx - simd_n_elem_qs_1; + int simd_byte = simd_idx % SIMD_LEN_qs_2; + int simd_trit = simd_idx / SIMD_LEN_qs_2; + + cur_qs = simd_qs[simd_byte] * pow3[simd_trit]; + trit = ((uint16_t) cur_qs * 3) >> 8; + } + else { + const int internal_offset = SIMD_LEN_qs_1 + SIMD_LEN_qs_2; + const uint8_t * simd_qs = qs + internal_offset; + int simd_idx = internal_idx - simd_n_elem_qs_1 - simd_n_elem_qs_2; + int simd_byte = simd_idx % SIMD_LEN_qh; + int simd_trit = simd_idx / SIMD_LEN_qh; + + cur_qs = simd_qs[simd_byte] * pow3[simd_trit]; + trit = ((uint16_t) cur_qs * 3) >> 8; + } + + return trit + 1; + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + ggml_fp16_t d = ((const block_t *) data)[idx / 
group_size].d; + if (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&d); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(d); + } + } +}; + +struct BlockTQ20TypeAccessor { + using block_t = block_tq2_0; + + static constexpr int BITS = 2; + static constexpr int SIMD_LEN = 32; + static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS; // 256 + static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS; // 128 + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs); + int internal_idx = idx % group_size; + const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN; + int simd_idx = internal_idx % simd_n_elem; + return (simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS)) + 1; + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d; + if (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&d); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(d); + } + } +}; + bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { if ((is_type_supported(src0->type)) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->backend == GGML_BACKEND_TYPE_CPU) { + src0->backend == GGML_BACKEND_TYPE_CPU && + strcmp(src0->name, "token_embd.weight") && // means not equal + strcmp(src0->name, "output.weight")) { return true; } return false; @@ -212,10 +312,18 @@ void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) { DLOG(INFO) << "kcfg (bm=" << bm << ", simd_n_in=" << simd_n_in << ", simd_n_out=" << simd_n_out << ", kfactor=" << kfactor << ", group_size=" << group_size << ", lut_scales_size=" << lut_scales_size << ", scales_size=" << scales_size << ", n_tile_num=" << n_tile_num << ")"; if (bm == 0) { - // Instead of fatal error, try to avoid using t-mac? - LOG(FATAL) << "Failed to find kcfg. Abort transforming"; - return; + // TODO: warning token.embd if not support + if (!strcmp(tensor->name, "token_embd.weight") || !strcmp(tensor->name, "output.weight")) { + LOG(WARNING) << "Do not find kcfg for " << tensor->name << ". Consider compiling T-MAC kernel for it if vocab size is a multiply of 128 or 320, detected " << tensor->ne[1] << "."; + return; + } + else { + // Instead of fatal error, try to avoid using t-mac? + LOG(FATAL) << "Failed to find kcfg. 
Abort transforming"; + return; + } } + const int mgroup = ngroups_per_elem * simd_n_in; m = m * bits; @@ -224,7 +332,7 @@ void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) { scales = (tmac_float_type *) aligned_malloc(scales_size * sizeof(tmac_float_type)); if (do_permutate(tensor->type)) { - qweights = (uint8_t *) aligned_malloc(k * m / 8); + qweights = (uint8_t *) aligned_malloc(k * m / 8); } else { qweights = (uint8_t *) tensor->data; float * i2_scales = (float * )(qweights + k * m / 8); @@ -254,13 +362,20 @@ void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) { // w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1) for (int im = 0; im < m / bits; im++) { for (int ik = 0; ik < k; ik++) { + uint8_t v; + if (tensor->type == GGML_TYPE_Q4_0) { + v = BlockQ40TypeAccessor::get_q(tensor->data, im * k + ik); + } else if (tensor->type == GGML_TYPE_I2) { + v = BlockI2TypeAccessor::get_q(tensor->data, im * k + ik); + } else if (tensor->type == GGML_TYPE_TQ1_0) { + v = BlockTQ10TypeAccessor::get_q(tensor->data, im * k + ik); + } else if (tensor->type == GGML_TYPE_TQ2_0) { + v = BlockTQ20TypeAccessor::get_q(tensor->data, im * k + ik); + } else { + LOG(FATAL) << "Unsupported type"; + } + for (int ib = 0; ib < bits; ib++) { - uint8_t v; - if (tensor->type == GGML_TYPE_Q4_0) { - v = BlockQ40TypeAccessor::get_q(tensor->data, im * k + ik); - } else if (tensor->type == GGML_TYPE_I2) { - v = BlockI2TypeAccessor::get_q(tensor->data, im * k + ik); - } buf1[im * k * bits + ik * bits + ib] = (v >> ib) & 1; } } @@ -353,6 +468,12 @@ void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) { scale = BlockQ40TypeAccessor::get_scale(tensor->data, idx); } else if (tensor->type == GGML_TYPE_I2) { scale = BlockI2TypeAccessor::get_scale(i2_scales, idx, group_size); + } else if (tensor->type == GGML_TYPE_TQ1_0) { + scale = BlockTQ10TypeAccessor::get_scale(tensor->data, idx, group_size); + } else if (tensor->type == GGML_TYPE_TQ2_0) { + scale = BlockTQ20TypeAccessor::get_scale(tensor->data, idx, group_size); + } else { + LOG(FATAL) << "Unsupported type"; } int new_idx; idx = idx / group_size; @@ -388,6 +509,10 @@ int ggml_tmac_get_type_bits(enum ggml_type type) { return 4; case GGML_TYPE_Q4_0: return 4; + case GGML_TYPE_TQ1_0: + return 2; + case GGML_TYPE_TQ2_0: + return 2; default: return 0; } From 080d2ecc56da271a7a6c053ff3cac0d2b167efbc Mon Sep 17 00:00:00 2001 From: Qingtao Li Date: Wed, 30 Oct 2024 10:46:16 +0800 Subject: [PATCH 09/12] Add run_pipeline option of rechunk. 
---
 ggml/src/CMakeLists.txt | 4 ++++
 ggml/src/ggml-cpu.c | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 7850fe0b5294d..00b39d455266b 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -897,6 +897,10 @@ if (GGML_TMAC)
         set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac_no_tvm)
         set(GGML_SOURCES_TMAC ${GGML_SOURCES_TMAC} ${TMAC_KERNELS_SOURCE})
     endif()
+
+    if (GGML_TMAC_RECHUNK)
+        add_compile_definitions(TMAC_RECHUNK)
+    endif()
 else()
     message(WARNING "TMAC not found")
 endif()
diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
index adc58ad0a783d..68fda5f338fa3 100644
--- a/ggml/src/ggml-cpu.c
+++ b/ggml/src/ggml-cpu.c
@@ -7577,13 +7577,14 @@ UseGgmlGemm1:;
     int64_t dr0 = chunk_size0;
     int64_t dr1 = chunk_size1;
 
-
+#if defined(TMAC_RECHUNK)
     // Rechunk
     if ((nchunk1 == 1) && (nchunk0 > nth * 4)) {
         // dr0 should be divisible by chunk_size0
         dr0 = (ne0 / (nth * 4) / chunk_size0) * chunk_size0;
         nchunk0 = (nr0 + dr0 - 1) / dr0;
     }
+#endif
 
     int current_chunk = ith;

From f84d25dd8fcf706e357b79ceda1437273d9b76ee Mon Sep 17 00:00:00 2001
From: Qingtao Li
Date: Wed, 30 Oct 2024 16:00:38 +0800
Subject: [PATCH 10/12] Limit enable_t_mac to take effect on INT_N only.

---
 convert_hf_to_gguf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index f44c4ca6d1dc7..7027948b9f3ef 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1798,8 +1798,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         ]):
             # transform weight into 1/0/-1 (in fp32)
             data_torch = self.weight_quant(data_torch)
-            if self.enable_t_mac:
-                # transform weight into T-MAC I2 format
+            if self.enable_t_mac and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
+                # transform weight into T-MAC INT_N format
                 from t_mac.model_utils import preprocess_for_t_mac
                 data = LazyTorchTensor.to_eager(data_torch).numpy()
                 scale = np.max(np.abs(data))

From 55a86969b8c15a82839dfc98e60c00657d9b37ec Mon Sep 17 00:00:00 2001
From: Qingtao Li
Date: Tue, 5 Nov 2024 18:40:31 +0800
Subject: [PATCH 11/12] [rebase] Fix build error.
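
Background note: this rebase fix moves the CPU-specific fields (vec_dot,
vec_dot_type, nrows) of the I1..I4 entries out of ggml.c's type_traits table
and into ggml-cpu.c's type_traits_cpu table, matching where the hunks below
place them. The sketch that follows is a simplified, self-contained mock of
that two-table layout; the demo_* names, enum values, and field sets are
stand-ins for illustration, not the real ggml definitions.

```c
#include <stddef.h>
#include <stdio.h>

enum demo_type { DEMO_TYPE_F32, DEMO_TYPE_I1, DEMO_TYPE_I2, DEMO_TYPE_COUNT };

struct demo_type_traits {       // core table: layout metadata only
    const char * type_name;
    int          blck_size;
    size_t       type_size;
};

struct demo_type_traits_cpu {   // CPU table: kernel selection
    void (*vec_dot)(int n, float * s, const float * x, const float * y);
    enum demo_type vec_dot_type;
    int            nrows;
};

static void demo_vec_dot_f32(int n, float * s, const float * x, const float * y) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += x[i] * y[i];
    *s = sum;
}

static const struct demo_type_traits type_traits[DEMO_TYPE_COUNT] = {
    [DEMO_TYPE_F32] = { "f32", 1, sizeof(float) },
    [DEMO_TYPE_I1]  = { "i1",  8, sizeof(char)  },
    [DEMO_TYPE_I2]  = { "i2",  4, sizeof(char)  },
};

static const struct demo_type_traits_cpu type_traits_cpu[DEMO_TYPE_COUNT] = {
    // low-bit types fall back to the f32 dot product, as in the hunk below
    [DEMO_TYPE_F32] = { demo_vec_dot_f32, DEMO_TYPE_F32, 1 },
    [DEMO_TYPE_I1]  = { demo_vec_dot_f32, DEMO_TYPE_F32, 1 },
    [DEMO_TYPE_I2]  = { demo_vec_dot_f32, DEMO_TYPE_F32, 1 },
};

int main(void) {
    enum demo_type t = DEMO_TYPE_I2;
    // layout comes from the core table, kernels from the CPU table
    printf("%s: blck_size=%d, vec_dot_type=%d, nrows=%d\n",
           type_traits[t].type_name, type_traits[t].blck_size,
           (int) type_traits_cpu[t].vec_dot_type, type_traits_cpu[t].nrows);
    return 0;
}
```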
---
 ggml/src/ggml-cpu.c | 20 ++++++++++++++++++++
 ggml/src/ggml.c | 16 ++++------------
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
index 68fda5f338fa3..db72c29d1a279 100644
--- a/ggml/src/ggml-cpu.c
+++ b/ggml/src/ggml-cpu.c
@@ -427,6 +427,26 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
     },
+    [GGML_TYPE_I1] = {
+        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type = GGML_TYPE_F32,
+        .nrows = 1,
+    },
+    [GGML_TYPE_I2] = {
+        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type = GGML_TYPE_F32,
+        .nrows = 1,
+    },
+    [GGML_TYPE_I3] = {
+        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type = GGML_TYPE_F32,
+        .nrows = 1,
+    },
+    [GGML_TYPE_I4] = {
+        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type = GGML_TYPE_F32,
+        .nrows = 1,
+    },
 };
 
 const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 823abbe175644..f2199c6e79f19 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -8,6 +8,10 @@
 #include "ggml.h"
 #include "ggml-aarch64.h"
 
+#if defined(GGML_USE_TMAC)
+#include "ggml-tmac.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -546,36 +550,24 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .blck_size = 8,
         .type_size = sizeof(int8_t),
         .is_quantized = false,
-        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
-        .vec_dot_type = GGML_TYPE_F32,
-        .nrows = 1,
     },
     [GGML_TYPE_I2] = {
         .type_name = "i2",
         .blck_size = 4,
         .type_size = sizeof(int8_t),
         .is_quantized = false,
-        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
-        .vec_dot_type = GGML_TYPE_F32,
-        .nrows = 1,
     },
     [GGML_TYPE_I3] = {
         .type_name = "i3",
         .blck_size = 2,
         .type_size = sizeof(int8_t),
         .is_quantized = false,
-        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
-        .vec_dot_type = GGML_TYPE_F32,
-        .nrows = 1,
     },
     [GGML_TYPE_I4] = {
         .type_name = "i4",
         .blck_size = 2,
         .type_size = sizeof(int8_t),
         .is_quantized = false,
-        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
-        .vec_dot_type = GGML_TYPE_F32,
-        .nrows = 1,
     },
     [GGML_TYPE_I8] = {
         .type_name = "i8",

From 3f7d85da1e68b081fea31a74abf1d88dd2a68463 Mon Sep 17 00:00:00 2001
From: Qingtao Li
Date: Wed, 6 Nov 2024 13:39:12 +0800
Subject: [PATCH 12/12] [fix] Put ggml_tmac_init at correct place.

---
 ggml/src/ggml-cpu.c | 4 ----
 ggml/src/ggml.c | 5 +++++
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
index db72c29d1a279..44e552aee2145 100644
--- a/ggml/src/ggml-cpu.c
+++ b/ggml/src/ggml-cpu.c
@@ -13891,10 +13891,6 @@ void ggml_cpu_init(void) {
         ggml_init_arm_arch_features();
 #endif
 
-#if defined(GGML_USE_TMAC)
-        ggml_tmac_init();
-#endif
-
         is_first_call = false;
     }
 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f2199c6e79f19..1bc193bf442e7 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1453,6 +1453,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             } u = {i};
             ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
         }
+
+#if defined(GGML_USE_TMAC)
+        ggml_tmac_init();
+#endif
+
         is_first_call = false;
     }
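
Appendix (reviewer note, not part of the patches): the TQ1_0 accessor added
earlier in this series recovers each ternary digit by multiplying the packed
byte by a power of 3 and taking the top byte of a further multiply by 3. The
self-contained sketch below shows that per-byte packing scheme, mirroring what
BlockTQ10TypeAccessor::get_q assumes (to the best of my reading of ggml's
TQ1_0 reference quantization); the helper names and demo digits are
illustrative only.

```c
#include <stdint.h>
#include <stdio.h>

static const uint8_t pow3[5] = {1, 3, 9, 27, 81};

// Pack five ternary digits (each 0, 1 or 2, i.e. weight + 1) into one byte.
static uint8_t pack5(const uint8_t t[5]) {
    uint16_t q = 0;
    for (int n = 0; n < 5; ++n) {
        q = q * 3 + t[n];                      // base-3 accumulation, 0..242
    }
    return (uint8_t) ((q * 256 + 242) / 243);  // ceiling-scale into 0..255
}

// Recover digit n (0..4) from a packed byte.
static uint8_t unpack1(uint8_t q, int n) {
    uint8_t shifted = (uint8_t) (q * pow3[n]);        // wraps mod 256 on purpose
    return (uint8_t) (((uint16_t) shifted * 3) >> 8); // high byte of *3 is the digit
}

int main(void) {
    const uint8_t trits[5] = {2, 0, 1, 1, 0};  // demo digits, i.e. weights +1 -1 0 0 -1
    uint8_t q = pack5(trits);                  // 184 for this input

    for (int n = 0; n < 5; ++n) {
        printf("digit %d: stored %d, weight %d\n",
               n, (int) unpack1(q, n), (int) unpack1(q, n) - 1);
    }
    return 0;
}
```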