Introduce New Lookup-Table (LUT)-Based Matrix Multiplication Method #10181

Open · wants to merge 12 commits into base: master
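Note for review context: T-MAC replaces multiply-accumulate with table lookups. The toy NumPy sketch below shows only the core idea, assuming 1-bit (±1) weights grouped four at a time along K; the PR's real kernels are generated ahead of time and operate on packed low-bit layouts with per-group scales.

```python
import numpy as np

# Toy LUT-based matvec for 1-bit weights (illustration only, not the PR's kernel).
# For each group of g activations, precompute the partial sum for every possible
# g-bit weight pattern once; each weight row then gathers from that table
# instead of doing multiplies.
def lut_matvec_1bit(w_bits: np.ndarray, x: np.ndarray, g: int = 4) -> np.ndarray:
    n, k = w_bits.shape                       # w_bits in {0, 1}; 1 -> +1, 0 -> -1
    assert k % g == 0
    groups = x.reshape(k // g, g)             # (k/g, g) activation groups

    # LUT: one partial sum per (group, bit-pattern) pair -> shape (k/g, 2^g)
    patterns = np.array([[1 if (p >> i) & 1 else -1 for i in range(g)]
                         for p in range(1 << g)], dtype=x.dtype)
    lut = groups @ patterns.T

    # Each weight row becomes k/g table indices; the dot product is gather + sum.
    idx = w_bits.reshape(n, k // g, g) @ (1 << np.arange(g))
    return lut[np.arange(k // g), idx].sum(axis=1)

# Check against the dense product:
rng = np.random.default_rng(0)
w_bits = rng.integers(0, 2, size=(8, 16))
x = rng.standard_normal(16).astype(np.float32)
ref = np.where(w_bits == 1, 1.0, -1.0) @ x
assert np.allclose(lut_matvec_1bit(w_bits, x), ref, atol=1e-5)
```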
133 changes: 125 additions & 8 deletions convert_hf_to_gguf.py
@@ -65,6 +65,8 @@
model_name: str | None
metadata_override: Path | None
dir_model_card: Path
enable_t_mac: bool
kcfg_file: Path | None

# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@@ -73,7 +75,8 @@
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None,
enable_t_mac: bool = False, kcfg_file: Path | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")

@@ -95,7 +98,8 @@
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py

self.enable_t_mac = enable_t_mac
self.kcfg_file = kcfg_file
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@@ -265,6 +269,73 @@

return [(self.map_tensor_name(name), data_torch)]

_gptq_quant_dict: dict[str, Tensor] | None = None
_t_mac_bits: int = 0
_t_mac_raw_shape: tuple[int, ...] | None = None

# Repack and merge qweight, scales, and qzeros into a single tensor
# Currently, this logic is nearly impossible to implement in quants.py
def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not self.enable_t_mac:
return self.modify_tensors(data_torch, name, bid)

# bits = 0 means not quantized
self._t_mac_bits = 0
self._t_mac_raw_shape = None

from t_mac.model_utils import get_quantization_config, preprocess_for_t_mac

Check failure on line 286 (GitHub Actions / pyright type-check): Import "t_mac.model_utils" could not be resolved (reportMissingImports)
quantization_config = get_quantization_config(self.dir_model)

if quantization_config["quant_method"] == "gptq": # AutoGPTQ/GPTQModel
if name.endswith(".g_idx"):
return []

if name.endswith(".qweight") or name.endswith(".scales") or name.endswith(".qzeros"):
if self._gptq_quant_dict is None:
self._gptq_quant_dict = {}
suffix = "." + name.split(".")[-1]
base_name = name.replace(suffix, "")
self._gptq_quant_dict.setdefault(base_name, {})[suffix] = data_torch

Check failure on line 298 (GitHub Actions / pyright type-check): Argument of type "str" cannot be assigned to parameter "indices" in function "__setitem__" (reportArgumentType)

Check failure on line 298 (GitHub Actions / pyright type-check): Argument of type "dict[Any, Any]" cannot be assigned to parameter "default" of type "Tensor" in function "setdefault" (reportArgumentType)
if len(self._gptq_quant_dict[base_name]) < 3:
return []

qweight = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qweight"]).numpy()

Check failure on line 302 (GitHub Actions / pyright type-check): Argument of type "Literal['.qweight']" cannot be assigned to parameter "indices" in function "__getitem__" (reportArgumentType)
scales = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".scales"]).numpy()

Check failure on line 303 (GitHub Actions / pyright type-check): Argument of type "Literal['.scales']" cannot be assigned to parameter "indices" in function "__getitem__" (reportArgumentType)
qzeros = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qzeros"]).numpy()

Check failure on line 304 (GitHub Actions / pyright type-check): Argument of type "Literal['.qzeros']" cannot be assigned to parameter "indices" in function "__getitem__" (reportArgumentType)
name = base_name + ".weight"
from t_mac.model_utils import unpack_gptqv2

Check failure on line 306 (GitHub Actions / pyright type-check): Import "t_mac.model_utils" could not be resolved (reportMissingImports)
w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros, "gptqmodel" in quantization_config["quantizer"])
self._t_mac_bits = bits
self._t_mac_raw_shape = w.shape
if bits != quantization_config["bits"] or group_size != quantization_config["group_size"]:
logger.warning("Error while parsing weights for quantization_config: {}".format(quantization_config))

# For permutation in, e.g., LlamaModel
w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy()

Check failure on line 314 (GitHub Actions / pyright type-check): "__getitem__" method not defined on type "Iterable[tuple[str, Tensor]]" (reportIndexIssue)
scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy()

Check failure on line 315 (GitHub Actions / pyright type-check): "__getitem__" method not defined on type "Iterable[tuple[str, Tensor]]" (reportIndexIssue)
zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy()

Check failure on line 316 (GitHub Actions / pyright type-check): "__getitem__" method not defined on type "Iterable[tuple[str, Tensor]]" (reportIndexIssue)

if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
if quantization_config["sym"]:
if not np.allclose(zeros, np.zeros_like(zeros)):
logger.warning("Although the quantized model claimed to be symmetric, the weights are asymmetric")
else:
zeros = None
data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scales, zeros, bits=bits))
else:
old_shape = w.shape
w = w.astype("float32").reshape(-1, group_size)
scales = scales.astype("float32").reshape(-1, 1)
zeros = zeros.astype("float32").reshape(-1, 1)
data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales
data_torch = torch.from_numpy(data.reshape(old_shape))
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_torch = data_torch.to(torch.float16)

return [(self.map_tensor_name(name), data_torch)]

return self.modify_tensors(data_torch, name, bid)
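For context on what `unpack_gptqv2` is undoing above: GPTQ checkpoints pack several low-bit integers into each int32 of `qweight`, and per-group zero points into `qzeros`. A rough NumPy sketch of the packing for the common 4-bit case (illustrative only; the exact field order differs between AutoGPTQ, GPTQv2, and GPTQModel, which is why the PR delegates to `t_mac.model_utils`):

```python
import numpy as np

def unpack_int32_fields(packed: np.ndarray, bits: int = 4) -> np.ndarray:
    """Unpack little-endian `bits`-wide unsigned fields from int32 words along axis 0."""
    per_word = 32 // bits                      # 8 fields per word when bits == 4
    shifts = np.arange(per_word, dtype=np.uint32) * bits
    u = packed.astype(np.uint32)[:, None, :]   # (words, 1, cols)
    vals = (u >> shifts[None, :, None]) & np.uint32((1 << bits) - 1)
    return vals.reshape(-1, packed.shape[1])   # (words * per_word, cols)

packed = np.array([[0x76543210]], dtype=np.int32)   # nibbles 0..7, lowest first
assert unpack_int32_fields(packed).ravel().tolist() == [0, 1, 2, 3, 4, 5, 6, 7]
```

Dequantization then follows the same formula as the asymmetric branch above: `data = (w - (zeros / scales + 2 ** (bits - 1))) * scales`.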

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid, n_dims # unused

@@ -285,7 +356,7 @@
old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
if data_torch.dtype not in (torch.float16, torch.float32) and not self.enable_t_mac:
data_torch = data_torch.to(torch.float32)

# use the first number-like part of the tensor name as the block id
@@ -295,7 +366,13 @@
bid = int(part)
break

for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
for new_name, data_torch in (self._modify_tensors(data_torch, name, bid)):
# Some GPTQ models have empty bias tensors which are not in the model architecture.
# These tensors will cause the tensor-count check to fail, so we have to skip them.
if new_name.endswith(".bias") and np.all(LazyTorchTensor.to_eager(data_torch).numpy() == 0):
logger.info(f"Skipping empty bias tensor: {new_name}")
continue

data = data_torch.squeeze().numpy()

# if data ends up empty, it means data_torch was a scalar tensor -> restore
@@ -344,6 +421,19 @@
# TODO: use Q4_K and Q6_K
data_qtype = gguf.GGMLQuantizationType.F16

# If self._t_mac_bits > 0, the tensor has already been quantized for T-MAC (GPTQ path above, or BitNet below)
if self.enable_t_mac and self._t_mac_bits > 0 and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
if self._t_mac_bits == 1:
data_qtype = gguf.GGMLQuantizationType.I1
elif self._t_mac_bits == 2:
data_qtype = gguf.GGMLQuantizationType.I2
elif self._t_mac_bits == 3:
data_qtype = gguf.GGMLQuantizationType.I3
elif self._t_mac_bits == 4:
data_qtype = gguf.GGMLQuantizationType.I4
else:
raise ValueError(f"Unsupported number of bits: {self._t_mac_bits}")

# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
if isinstance(data_qtype, bool):
if self.ftype == gguf.LlamaFileType.ALL_F32:
@@ -358,6 +448,12 @@
data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
data_qtype = gguf.GGMLQuantizationType.TQ2_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
# If the tensor is successfully quantized, data_qtype should be I1/2/3/4
# If data_qtype is still bool, then the tensor should not be quantized
# In practice, this tensor is `output.weight` for GPTQ models
# TODO: Consider quantizing it?
data_qtype = gguf.GGMLQuantizationType.F16
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")

@@ -369,14 +465,16 @@
data = gguf.quants.quantize(data, data_qtype)

shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
shape = self._t_mac_raw_shape or shape

# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"

# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
raw_shape = gguf.quant_shape_to_byte_shape(self._t_mac_raw_shape, data_qtype) if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N and self._t_mac_raw_shape else None
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype, raw_shape=raw_shape)
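Context for the `raw_shape` plumbing above: GGUF records quantized tensors by their byte shape, so for the T-MAC I1–I4 types the writer is given the original element shape explicitly. A sketch of the assumed conversion semantics (the real helpers in gguf-py read `block_size`/`type_size` from `GGML_QUANT_SIZES`):

```python
def quant_shape_to_byte_shape(shape, block_size, type_size):
    *rest, last = shape
    assert last % block_size == 0, "last dim must be a whole number of blocks"
    return (*rest, last // block_size * type_size)

# e.g. Q4_0 stores blocks of 32 elements in 18 bytes: (4096, 4096) -> (4096, 2304)
assert quant_shape_to_byte_shape((4096, 4096), 32, 18) == (4096, 2304)
```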

def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MODEL)
@@ -1700,6 +1798,15 @@
]):
# transform weight into 1/0/-1 (in fp32)
data_torch = self.weight_quant(data_torch)
if self.enable_t_mac and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
# transform weight into T-MAC INT_N format
from t_mac.model_utils import preprocess_for_t_mac
data = LazyTorchTensor.to_eager(data_torch).numpy()
scale = np.max(np.abs(data))
w = np.round(data / scale + 2).astype(np.uint8)
data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scale.reshape(1), bits=2))
self._t_mac_bits = 2
self._t_mac_raw_shape = w.shape

yield (new_name, data_torch)
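A quick sanity check of the encoding used in this hunk: after `weight_quant` the weights are ternary multiples of a single scale, and the `+ 2` offset stores them as small unsigned integers that decode back exactly (a self-contained check, assuming the input really is ternary):

```python
import numpy as np

data = np.array([-1.0, 0.0, 1.0, -1.0], dtype=np.float32) * 0.37  # ternary * scale
scale = np.max(np.abs(data))
w = np.round(data / scale + 2).astype(np.uint8)                   # {-1,0,1} -> {1,2,3}
assert w.tolist() == [1, 2, 3, 1]
assert np.allclose((w.astype(np.float32) - 2) * scale, data)      # exact round trip
```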

@@ -4297,8 +4404,8 @@
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "int_n", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and int_n for int1/2/3/4, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -4344,6 +4451,14 @@
"--metadata", type=Path,
help="Specify the path for an authorship metadata override file"
)
parser.add_argument(
"--enable-t-mac", action="store_true",
help="Enable T-MAC quantization format (disabled by default). Support GPTQ, GPTQv2, BitNet and BitDistiller."
)
parser.add_argument(
"--kcfg", type=Path,
help="Specify the path for the T-MAC configuration file"
)
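Taken together with the new `int_n` outtype, a T-MAC conversion would presumably be invoked like this (paths are placeholders; `kcfg.ini` is produced by the T-MAC build, which the CMake changes below also copy next to the library):

```sh
python convert_hf_to_gguf.py /path/to/gptq-model \
    --outtype int_n --enable-t-mac --kcfg /path/to/kcfg.ini
```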

return parser.parse_args()

@@ -4387,6 +4502,7 @@
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
"tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
"int_n": gguf.LlamaFileType.MOSTLY_INT_N,
"auto": gguf.LlamaFileType.GUESSED,
}

@@ -4420,7 +4536,8 @@
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
small_first_shard=args.no_tensor_first_split,
enable_t_mac=args.enable_t_mac, kcfg_file=args.kcfg)

if args.vocab_only:
logger.info("Exporting model vocab...")
3 changes: 3 additions & 0 deletions ggml/CMakeLists.txt
@@ -166,6 +166,9 @@ option(GGML_SYCL "ggml: use SYCL"
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
option(GGML_TMAC "ggml: use TMAC" OFF)
option(GGML_TMAC_SYSLIB "ggml: use TMAC system library" OFF)
option(GGML_TMAC_TVM_THREADPOOL "ggml: use TVM threadpool for TMAC" OFF)

# extra artifacts
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
42 changes: 42 additions & 0 deletions ggml/include/ggml-tmac.h
@@ -0,0 +1,42 @@
#pragma once

#include "ggml.h"
#include "ggml-backend.h"

#ifdef __ARM_NEON
#include <arm_neon.h>
typedef float16_t tmac_float_type;
#else
typedef float tmac_float_type;
#endif

#ifdef __cplusplus
extern "C" {
#endif

struct tmac_tensor_extra {
int lut_scales_size;
int scales_size;
int n_tile_num;
uint8_t * qweights;
tmac_float_type * scales;
};

GGML_API void ggml_tmac_init(void);
GGML_API void ggml_tmac_free(void);
// src0->type == Q4_0/IQ2_XXS/IQ3_XXS
// T-MAC currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
// If i-quantization GGUF models are used, the results will be wrong
// TODO: add customized block types Q2_0/Q3_0
GGML_API bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
GGML_API size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
GGML_API void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
GGML_API void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
GGML_API void ggml_tmac_transform_tensor(struct ggml_tensor * tensor);
GGML_API int ggml_tmac_get_type_bits(enum ggml_type type);
GGML_API void ggml_tmac_set_n_threads(int n_threads);
GGML_API size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor);
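
// Presumed call sequence for one mul_mat, inferred from the declarations above
// (not documented in this header): ggml_tmac_can_mul_mat() to gate dispatch,
// ggml_tmac_mul_mat_get_wsize() to size the scratch buffer,
// ggml_tmac_mul_mat_task_init() to quantize src1 and build the lookup table,
// then ggml_tmac_mul_mat_task_compute() to run the table-lookup GEMM.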

#ifdef __cplusplus
}
#endif
4 changes: 4 additions & 0 deletions ggml/include/ggml.h
@@ -389,6 +389,10 @@ extern "C" {
GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_I1 = 36,
GGML_TYPE_I2 = 37,
GGML_TYPE_I3 = 38,
GGML_TYPE_I4 = 39,
GGML_TYPE_COUNT,
};

61 changes: 60 additions & 1 deletion ggml/src/CMakeLists.txt
@@ -870,6 +870,42 @@ if (GGML_KOMPUTE)
endif()
endif()

if (GGML_TMAC)
find_package(TMAC)

if (TMAC_FOUND)
message(STATUS "TMAC found")

list(APPEND GGML_CDEF_PUBLIC GGML_USE_TMAC)

set(GGML_HEADERS_TMAC ../include/ggml-tmac.h)
set(GGML_SOURCES_TMAC ggml-tmac.cpp)

link_directories(${TMAC_LIB_DIR})
file(COPY ${TMAC_LIB_DIR}/kcfg.ini DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
# TODO: link t_mac_object when GGML_TMAC_SYSLIB

if (GGML_TMAC_TVM_THREADPOOL)
add_compile_definitions(TMAC_USE_TVM_THREADPOOL)
set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac)
else()
if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR
(NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
message(FATAL_ERROR "Clang is required for T-MAC compilation")
endif()

set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac_no_tvm)
set(GGML_SOURCES_TMAC ${GGML_SOURCES_TMAC} ${TMAC_KERNELS_SOURCE})
endif()

if (GGML_TMAC_RECHUNK)
add_compile_definitions(TMAC_RECHUNK)
endif()
else()
message(WARNING "TMAC not found")
endif()
endif()

if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)

@@ -1170,6 +1206,26 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
# Raspberry Pi 3, 4, Zero 2 (32-bit)
list(APPEND ARCH_FLAGS -mno-unaligned-access)
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" AND GGML_TMAC AND TMAC_FOUND)
# We need fullfp16 for T-MAC
# TODO: we need to simplify this logic through check_cxx_source_compiles or Presets?
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
# Device with armv8.7a+ cpu, e.g., WSL on Surface Laptop 7
# based on arm64-windows-llvm.cmake
list(APPEND ARCH_FLAGS -march=armv8.7-a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only)
add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
else ()
# Jetson AGX Orin, Raspberry Pi 5
list(APPEND ARCH_FLAGS -march=armv8.2a+fp16)
endif ()
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ARM64" AND GGML_TMAC AND TMAC_FOUND)
# ARM Windows with LLVM clang GNU interface
# We need fullfp16 for T-MAC
# TODO: check_cxx_source_compiles
list(APPEND ARCH_FLAGS -march=armv8.2a+fp16)
endif()
if (GGML_SVE)
list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
endif()
@@ -1184,7 +1240,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
# TODO: improve, should not reference files from the parent folder
include(../cmake/FindSIMD.cmake)
endif ()
if (GGML_AVX512)
# Can't use GGML_AVX512 with Clang for MSVC,
# which fails with error: conflicting types for '_m_prefetchw'
if (GGML_AVX512 AND (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") AND (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
list(APPEND ARCH_FLAGS /arch:AVX512)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
Expand Down Expand Up @@ -1388,6 +1446,7 @@ add_library(ggml
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
${GGML_SOURCES_AMX} ${GGML_HEADERS_AMX}
${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
${GGML_SOURCES_TMAC} ${GGML_HEADERS_TMAC}
ggml-aarch64.c ggml-aarch64.h
)
