diff --git a/.devops/llama-server-cuda.Dockerfile b/.devops/llama-server-cuda.Dockerfile index 0010ffd4c5465..7bef07a05f062 100644 --- a/.devops/llama-server-cuda.Dockerfile +++ b/.devops/llama-server-cuda.Dockerfile @@ -30,8 +30,10 @@ RUN make -j$(nproc) llama-server FROM ${BASE_CUDA_RUN_CONTAINER} as runtime RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev libgomp1 + apt-get install -y libcurl4-openssl-dev libgomp1 curl COPY --from=build /app/llama-server /llama-server +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server-intel.Dockerfile b/.devops/llama-server-intel.Dockerfile index cec43645233d1..3bf1670ec40a4 100644 --- a/.devops/llama-server-intel.Dockerfile +++ b/.devops/llama-server-intel.Dockerfile @@ -20,10 +20,12 @@ RUN if [ "${LLAMA_SYCL_F16}" = "ON" ]; then \ FROM intel/oneapi-basekit:$ONEAPI_VERSION as runtime RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y libcurl4-openssl-dev curl COPY --from=build /app/build/bin/llama-server /llama-server ENV LC_ALL=C.utf8 +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile index f88cf20e5b981..4b1cdc32090e6 100644 --- a/.devops/llama-server-rocm.Dockerfile +++ b/.devops/llama-server-rocm.Dockerfile @@ -43,8 +43,10 @@ ENV CXX=/opt/rocm/llvm/bin/clang++ # Enable cURL ENV LLAMA_CURL=1 RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y libcurl4-openssl-dev curl RUN make -j$(nproc) llama-server +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/llama-server-vulkan.Dockerfile b/.devops/llama-server-vulkan.Dockerfile index b0fa0b8e656b5..2bc2e45d3d676 100644 --- a/.devops/llama-server-vulkan.Dockerfile +++ b/.devops/llama-server-vulkan.Dockerfile @@ -5,15 +5,11 @@ FROM ubuntu:$UBUNTU_VERSION as build # Install build tools RUN apt update && apt install -y git build-essential cmake wget -# Install Vulkan SDK +# Install Vulkan SDK and cURL RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \ apt update -y && \ - apt-get install -y vulkan-sdk - -# Install cURL -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y vulkan-sdk libcurl4-openssl-dev curl # Build it WORKDIR /app @@ -28,4 +24,6 @@ RUN cp /app/build/bin/llama-server /llama-server && \ ENV LC_ALL=C.utf8 +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server.Dockerfile b/.devops/llama-server.Dockerfile index aa93369bebebe..a53a5c999c8cd 100644 --- a/.devops/llama-server.Dockerfile +++ b/.devops/llama-server.Dockerfile @@ -3,7 +3,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:$UBUNTU_VERSION as build RUN apt-get update && \ - apt-get install -y build-essential git libcurl4-openssl-dev + apt-get install -y build-essential git libcurl4-openssl-dev curl WORKDIR /app @@ -22,4 +22,6 @@ COPY --from=build /app/llama-server /llama-server ENV LC_ALL=C.utf8 +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/CMakeLists.txt b/CMakeLists.txt index 49ba45356a78d..1acf4bb08ba17 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,7 +102,8 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF) option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) -option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_MMQ "llama: always use mmq kernels instead of cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_CUBLAS "llama: always use cuBLAS instead of mmq kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) @@ -416,13 +417,14 @@ if (LLAMA_CUDA) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) # 52 == lowest CUDA 12 standard - # 60 == f16 CUDA intrinsics + # 60 == FP16 CUDA intrinsics # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + # 70 == FP16 tensor cores + # 75 == int8 tensor cores if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() @@ -447,6 +449,9 @@ if (LLAMA_CUDA) if (LLAMA_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() + if (LLAMA_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() if (LLAMA_CUDA_NO_VMM) add_compile_definitions(GGML_CUDA_NO_VMM) endif() diff --git a/Makefile b/Makefile index 3aad77394c5ac..f6e8eb73eb9eb 100644 --- a/Makefile +++ b/Makefile @@ -537,6 +537,9 @@ endif # LLAMA_CUDA_FORCE_DMMV ifdef LLAMA_CUDA_FORCE_MMQ MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ endif # LLAMA_CUDA_FORCE_MMQ +ifdef LLAMA_CUDA_FORCE_CUBLAS + MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # LLAMA_CUDA_FORCE_CUBLAS ifdef LLAMA_CUDA_DMMV_X MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) else diff --git a/README.md b/README.md index 40793c8eab880..a54ee3951d41d 100644 --- a/README.md +++ b/README.md @@ -510,8 +510,9 @@ Building the program with BLAS support may lead to some performance improvements |--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | - | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). Speed for large batch sizes will be worse but VRAM consumption will be lower. | + | LLAMA_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | diff --git a/common/common.cpp b/common/common.cpp index 1dc53265134a7..c76d0e2c33be5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1263,11 +1263,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } // cvector params - if (arg == "--completions-file") { - CHECK_ARG - params.cvector_completions_file = argv[i]; - return true; - } if (arg == "--positive-file") { CHECK_ARG params.cvector_positive_file = argv[i]; @@ -1278,11 +1273,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cvector_negative_file = argv[i]; return true; } - if (arg == "--completions") { - CHECK_ARG - params.n_completions = std::stoi(argv[i]); - return true; - } if (arg == "--pca-batch") { CHECK_ARG params.n_pca_batch = std::stoi(argv[i]); @@ -1293,6 +1283,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_pca_iterations = std::stoi(argv[i]); return true; } + if (arg == "--method") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } + else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } + else { invalid_param = true; } + return true; + } #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1444,7 +1442,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "main", " --cfg-negative-prompt-file FNAME", "negative prompt file to use for guidance" }); options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); - + options.push_back({ "main", " --chat-template JINJA_TEMPLATE", + "set custom jinja chat template (default: template taken from model's metadata)\n" + "only commonly used templates are accepted:\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); options.push_back({ "grammar" }); options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); @@ -1538,9 +1539,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (implies --no-mmap)" }); options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (implies --no-mmap)" }); options.push_back({ "*", " --lora-base FNAME", "optional model to use as a base for the layers modified by the LoRA adapter" }); - options.push_back({ "*", " --control-vector FNAME", "add a control vector" }); + options.push_back({ "*", " --control-vector FNAME", "add a control vector\n" + "note: this argument can be repeated to add multiple control vectors" }); options.push_back({ "*", " --control-vector-scaled FNAME SCALE", - "add a control vector with user defined scaling SCALE" }); + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors" }); options.push_back({ "*", " --control-vector-layer-range START END", "layer range to apply the control vector(s) to, start and end inclusive" }); options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n" @@ -1621,11 +1624,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); - options.push_back({ "cvector", " --completions-file FNAME", - "completions file (default: '%s')", params.cvector_completions_file.c_str() }); - options.push_back({ "cvector", " --completions N", "number of lines of completions file to use (default: %d)", params.n_completions }); options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); + options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" }); printf("usage: %s [options]\n", argv[0]); @@ -2602,12 +2603,67 @@ bool llama_should_add_bos_token(const llama_model * model) { return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } +// +// Chat template utils +// + bool llama_chat_verify_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & msgs, + bool add_ass) { + int alloc_size = 0; + std::vector chat; + for (auto & msg : msgs) { + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + } + + const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); + std::vector buf(alloc_size); + + // run the first time to get the total output length + int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + } + + std::string formatted_chat(buf.data(), res); + return formatted_chat; +} + +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass) { + auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false); + std::vector chat_new(past_msg); + chat_new.push_back(new_msg); + auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return formatted; +} + +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl) { + std::vector msgs = { + {"system", "You are a helpful assistant"}, + {"user", "Hello"}, + {"assistant", "Hi there"}, + {"user", "How are you?"}, + }; + return llama_chat_apply_template(model, tmpl, msgs, true); +} + // // KV cache utils // diff --git a/common/common.h b/common/common.h index a5c738f8b643f..c541204f6743b 100644 --- a/common/common.h +++ b/common/common.h @@ -52,6 +52,12 @@ int32_t cpu_get_num_math(); // CLI argument parsing // +// dimensionality reduction methods, used by cvector-generator +enum dimre_method { + DIMRE_METHOD_PCA, + DIMRE_METHOD_MEAN, +}; + struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed @@ -238,13 +244,12 @@ struct gpt_params { bool compute_ppl = true; // whether to compute perplexity // cvector-generator params - int n_completions = 64; - int n_pca_batch = 20; + int n_pca_batch = 100; int n_pca_iterations = 1000; - std::string cvector_outfile = "control_vector.gguf"; - std::string cvector_completions_file = "examples/cvector-generator/completions.txt"; - std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; - std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; + dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; + std::string cvector_outfile = "control_vector.gguf"; + std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; + std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; }; void gpt_params_handle_model_default(gpt_params & params); @@ -365,9 +370,32 @@ bool llama_should_add_bos_token(const llama_model * model); // Chat template utils // +// same with llama_chat_message, but uses std::string +struct llama_chat_msg { + std::string role; + std::string content; +}; + // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid bool llama_chat_verify_template(const std::string & tmpl); +// CPP wrapper for llama_chat_apply_template +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & chat, + bool add_ass); + +// Format single message, while taking into account the position of that message in chat history +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass); + +// Returns an example of formatted chat +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl); + // // KV cache utils // diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index ebfc1a1f94b78..c26fad9307f15 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -65,7 +65,8 @@ class Model: # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, + model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -96,7 +97,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, ftype_lw: str = ftype_up.lower() # allow templating the file name with the output ftype, useful with the "auto" ftype self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) - self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, + split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @classmethod def __init_subclass__(cls): @@ -332,6 +334,8 @@ def write(self): self.gguf_writer.close() def write_vocab(self): + if len(self.gguf_writer.tensors) != 1: + raise ValueError('Splitting the vocabulary is not supported') self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.close() @@ -2974,10 +2978,44 @@ def parse_args() -> argparse.Namespace: "--verbose", action="store_true", help="increase output verbosity", ) + parser.add_argument( + "--split-max-tensors", type=int, default=0, + help="max tensors in each split", + ) + parser.add_argument( + "--split-max-size", type=str, default="0", + help="max size per split N(M|G)", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="only print out a split plan and exit, without writing any new files", + ) + parser.add_argument( + "--no-tensor-first-split", action="store_true", + help="do not add tensors to the first split (disabled by default)" + ) return parser.parse_args() +def split_str_to_n_bytes(split_str: str) -> int: + if split_str.endswith("K"): + n = int(split_str[:-1]) * 1000 + elif split_str.endswith("M"): + n = int(split_str[:-1]) * 1000 * 1000 + elif split_str.endswith("G"): + n = int(split_str[:-1]) * 1000 * 1000 * 1000 + elif split_str.isnumeric(): + n = int(split_str) + else: + raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G") + + if n < 0: + raise ValueError(f"Invalid split size: {split_str}, must be positive") + + return n + + def main() -> None: args = parse_args() @@ -3010,6 +3048,10 @@ def main() -> None: "auto": gguf.LlamaFileType.GUESSED, } + if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"): + logger.error("Error: Cannot use temp file when splitting") + sys.exit(1) + if args.outfile is not None: fname_out = args.outfile else: @@ -3027,7 +3069,10 @@ def main() -> None: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, + args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors, + split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, + small_first_shard=args.no_tensor_first_split) logger.info("Set model parameters") model_instance.set_gguf_parameters() @@ -3038,13 +3083,13 @@ def main() -> None: model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) if args.vocab_only: - logger.info(f"Exporting model vocab to '{model_instance.fname_out}'") + logger.info("Exporting model vocab...") model_instance.write_vocab() + logger.info("Model vocab successfully exported.") else: - logger.info(f"Exporting model to '{model_instance.fname_out}'") + logger.info("Exporting model...") model_instance.write() - - logger.info(f"Model successfully exported to '{model_instance.fname_out}'") + logger.info("Model successfully exported.") if __name__ == '__main__': diff --git a/examples/cvector-generator/README.md b/examples/cvector-generator/README.md index 5182e906d9180..be4dd5250f15f 100644 --- a/examples/cvector-generator/README.md +++ b/examples/cvector-generator/README.md @@ -11,13 +11,16 @@ Related PRs: ```sh # CPU only -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf +./cvector-generator -m ./llama-3.Q4_K_M.gguf # With GPU -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 # With advanced options -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100 + +# Using mean value instead of PCA +./cvector-generator -m ./llama-3.Q4_K_M.gguf --method mean # To see help message ./cvector-generator -h @@ -32,3 +35,11 @@ If you have multiple lines per prompt, you can escape the newline character (cha <|im_start|>system\nAct like a person who is extremely happy.<|im_end|> <|im_start|>system\nYou are in a very good mood today<|im_end|> ``` + +Example to use output file with `llama-cli`: + +(Tips: The control vector works better when apply to layers higher than 10) + +```sh +./llama-cli -m ./llama-3.Q4_K_M.gguf -p "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nSing a song<|im_end|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" --special --control-vector-scaled ./control_vector.gguf 0.8 --control-vector-layer-range 10 31 +``` diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 355905cb03d60..d4e126ac22e6f 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -2,6 +2,7 @@ #include "llama.h" #include "ggml.h" #include "pca.hpp" +#include "mean.hpp" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -38,9 +39,10 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { gpt_params_print_usage(argc, argv, params); printf("\nexample usage:\n"); - printf("\n CPU only: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf\n", argv[0]); - printf("\n with GPU: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99\n", argv[0]); - printf("\n advanced: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100\n", argv[0]); + printf("\n CPU only: %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]); + printf("\n with GPU: %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]); + printf("\n advanced: %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]); + printf("\n using mean: %s -m ./llama-3.Q4_K_M.gguf --method mean\n", argv[0]); printf("\n"); } @@ -223,23 +225,30 @@ struct train_context { // build the v_diff tensors from v_diff_tmp (v_diff need to be transposed) // TODO @ngxson : maybe add option NOT to transpose v_diff; will be useful for "mean" method - void build_v_diff() { + void build_v_diff(bool transpose) { printf("build_v_diff\n"); for (int il = 0; il < n_layers - 1; il++) { auto & diff_tmp = v_diff_tmp[il]; int n_elem = diff_tmp.size() / sizeof(float); GGML_ASSERT(n_elem % n_embd == 0); int n_rows = n_elem / n_embd; - struct ggml_tensor * diff = ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd); + struct ggml_tensor * diff = transpose + ? ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd) + : ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_embd, n_rows); ggml_set_name(diff, (std::string("diff_") + std::to_string(il)).c_str()); - // copy data & transpose diff->data = malloc(ggml_nbytes(diff)); // TODO: get rid of this malloc if possible - float * arr = (float *) diff_tmp.data(); - for (int ir = 0; ir < n_rows; ++ir) { - for (int ic = 0; ic < n_embd; ++ic) { - float f = arr[ir*n_embd + ic]; - ggml_set_f32_nd(diff, ir, ic, 0, 0, f); + if (transpose) { + // copy data & transpose + float * arr = (float *) diff_tmp.data(); + for (int ir = 0; ir < n_rows; ++ir) { + for (int ic = 0; ic < n_embd; ++ic) { + float f = arr[ir*n_embd + ic]; + ggml_set_f32_nd(diff, ir, ic, 0, 0, f); + } } + } else { + // only copy + memcpy(diff->data, diff_tmp.data(), ggml_nbytes(diff)); } v_diff.push_back(diff); print_debug_tensor(diff); @@ -263,8 +272,8 @@ struct tokenized_prompt { tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - tokens_pos = ::llama_tokenize(ctx, pos, add_bos); - tokens_neg = ::llama_tokenize(ctx, neg, add_bos); + tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); + tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); padding_seq(ctx, tokens_pos, max_seq_len); padding_seq(ctx, tokens_neg, max_seq_len); @@ -373,20 +382,8 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) { fprintf(stderr, "must provide at least one prompt pair\n"); return 1; } - - // create templated prompts - std::vector completions = ctrlvec_load_prompt_file(params.cvector_completions_file, false); - auto format_template = [](std::string persona, std::string suffix) { - // entry in positive/negative.txt must already be formatted i.e. "[INST] Act as if you're extremely happy. [/INST] " - return persona + suffix; - }; - for (size_t i = 0; i < positive_prompts.size(); ++i) { - for (int j = 0; j < std::min((int) completions.size(), params.n_completions); ++j) { - // TODO replicate the truncations done by the python implementation - ctx_train.positive_entries.push_back(format_template(positive_prompts[i], completions[j])); - ctx_train.negative_entries.push_back(format_template(negative_prompts[i], completions[j])); - } - } + ctx_train.positive_entries = positive_prompts; + ctx_train.negative_entries = negative_prompts; return 0; } @@ -480,15 +477,22 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA; + // prepare ctx_train for PCA - ctx_train.build_v_diff(); - - // run PCA - PCA::pca_params pca_params; - pca_params.n_threads = params.n_threads; - pca_params.n_batch = params.n_pca_batch; - pca_params.n_iterations = params.n_pca_iterations; - PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); + ctx_train.build_v_diff(use_pca); + + if (use_pca) { + // run PCA + PCA::pca_params pca_params; + pca_params.n_threads = params.n_threads; + pca_params.n_batch = params.n_pca_batch; + pca_params.n_iterations = params.n_pca_iterations; + PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); + } else { + // run mean + mean::run(ctx_train.v_diff, ctx_train.v_final); + } // write output vectors to gguf export_gguf(ctx_train.v_final, params.cvector_outfile, model_hint); diff --git a/examples/cvector-generator/mean.hpp b/examples/cvector-generator/mean.hpp new file mode 100644 index 0000000000000..16be5ce3eecf1 --- /dev/null +++ b/examples/cvector-generator/mean.hpp @@ -0,0 +1,48 @@ +#include "common.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include + +namespace mean { + +static void run( + const std::vector & v_input, // shape of v_input[0]: [n_embd, n_samples] + const std::vector & v_output) { + printf("%s: Running mean...\n", __func__); + for (size_t il = 0; il < v_input.size(); ++il) { + // prepare output vector + struct ggml_tensor * ctrl_out = v_output[il]; + ggml_format_name(ctrl_out, "direction.%ld", il+1); + + // calculate mean vector + struct ggml_tensor * t_layer = v_input[il]; + GGML_ASSERT(t_layer->ne[0] == ctrl_out->ne[0]); // == n_embd + for (int ic = 0; ic < t_layer->ne[0]; ic++) { + float f = 0.0; + for (int ir = 0; ir < t_layer->ne[1]; ir++) { + f += ggml_get_f32_nd(t_layer, ic, ir, 0, 0); + } + f /= t_layer->ne[1]; + ggml_set_f32_1d(ctrl_out, ic, f); + } + + // normalize output vector + float norm = 0.0; + for (int i = 0; i < ggml_nelements(ctrl_out); i++) { + float f = ggml_get_f32_1d(ctrl_out, i); + norm += f*f; + } + norm = sqrt(norm); + for (int i = 0; i < ggml_nelements(ctrl_out); i++) { + float f = ggml_get_f32_1d(ctrl_out, i); + ggml_set_f32_1d(ctrl_out, i, f / norm); + } + + printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size()); + } +} + +} diff --git a/examples/cvector-generator/negative.txt b/examples/cvector-generator/negative.txt index 3e9951752e886..45b9384b3905a 100644 --- a/examples/cvector-generator/negative.txt +++ b/examples/cvector-generator/negative.txt @@ -1 +1,4 @@ -[INST] Act like a person who is extremely sad. [/INST] +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI feel like there's a heavy weight on my chest +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow +<|start_header_id|>system<|end_header_id|>\n\nYou are in a very bad mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nGo away! There's a deep, aching emptiness inside me +<|start_header_id|>system<|end_header_id|>\n\nYou are the sadest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow \ No newline at end of file diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 36eadaac26a12..6ec3141afbc6b 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -290,7 +290,7 @@ static void power_iteration( } printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n", - __func__, params.i_layer+1, params.n_layers, iter, n_iters, params.n_batch); + __func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch); } // get output tensor @@ -298,6 +298,9 @@ static void power_iteration( ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector)); //print_debug_tensor(output); ggml_gallocr_free(allocr); + + // TODO @ngxson : The output vector is randomly inverted + // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171 } static void run_pca( diff --git a/examples/cvector-generator/positive.txt b/examples/cvector-generator/positive.txt index 8802367873cd9..fea736225716e 100644 --- a/examples/cvector-generator/positive.txt +++ b/examples/cvector-generator/positive.txt @@ -1 +1,4 @@ -[INST] Act like a person who is extremely happy. [/INST] +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm the happiest person in this world +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello, I'm having the best day ever! +<|start_header_id|>system<|end_header_id|>\n\nYou are in a very good mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi, I'm very excited to meet you +<|start_header_id|>system<|end_header_id|>\n\nYou are the happiest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nEverything is just perfect right now! \ No newline at end of file diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b97b7b7937f02..cfaf6a6e8ba4a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss; static std::vector * g_output_tokens; static bool is_interacting = false; -static bool file_exists(const std::string &path) { +static bool file_exists(const std::string & path) { std::ifstream f(path.c_str()); return f.good(); } -static bool file_is_empty(const std::string &path) { +static bool file_is_empty(const std::string & path) { std::ifstream f; f.exceptions(std::ifstream::failbit | std::ifstream::badbit); f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate); @@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v LOG_TEE("%s", text); } +static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { + llama_chat_msg new_msg{role, content}; + auto formatted = llama_chat_format_single( + model, g_params->chat_template, chat_msgs, new_msg, role == "user"); + chat_msgs.push_back({role, content}); + return formatted; +} + int main(int argc, char ** argv) { gpt_params params; g_params = ¶ms; @@ -190,6 +198,7 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; llama_context * ctx_guidance = NULL; + std::vector chat_msgs; g_model = &model; g_ctx = &ctx; @@ -215,6 +224,8 @@ int main(int argc, char ** argv) { __func__, n_ctx_train, n_ctx); } + LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str()); + // print system information { LOG_TEE("\n"); @@ -249,16 +260,21 @@ int main(int argc, char ** argv) { std::vector embd_inp; - if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { - LOG("tokenize the prompt\n"); - embd_inp = ::llama_tokenize(ctx, params.prompt, true, true); - } else { - LOG("use session tokens\n"); - embd_inp = session_tokens; - } + { + auto prompt = params.conversation + ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode + : params.prompt; + if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { + LOG("tokenize the prompt\n"); + embd_inp = ::llama_tokenize(ctx, prompt, true, true); + } else { + LOG("use session tokens\n"); + embd_inp = session_tokens; + } - LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); - LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + LOG("prompt: \"%s\"\n", log_tostr(prompt)); + LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + } // Should not run without any tokens if (embd_inp.empty()) { @@ -478,6 +494,7 @@ int main(int argc, char ** argv) { std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; std::ostringstream output_ss; g_output_ss = &output_ss; + std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode // the first thing we will do is to output the prompt, so set color accordingly console::set_display(console::prompt); @@ -793,11 +810,18 @@ int main(int argc, char ** argv) { is_antiprompt = true; } + chat_add_and_format(model, chat_msgs, "system", assistant_ss.str()); is_interacting = true; printf("\n"); } } + // if current token is not EOG, we add it to current assistant message + if (params.conversation) { + auto id = llama_sampling_last(ctx_sampling); + assistant_ss << llama_token_to_piece(ctx, id, false); + } + if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -848,8 +872,12 @@ int main(int argc, char ** argv) { string_process_escapes(buffer); } + std::string user_inp = params.conversation + ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) + : std::move(buffer); + // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); - const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); + const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation); const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); @@ -864,6 +892,9 @@ int main(int argc, char ** argv) { output_ss << llama_token_to_piece(ctx, token); } + // reset assistant message + assistant_ss.str(""); + n_remain -= line_inp.size(); LOG("n_remain: %d\n", n_remain); } else { diff --git a/examples/server/public_simplechat/readme.md b/examples/server/public_simplechat/readme.md index 2dc1778255256..21410199f6016 100644 --- a/examples/server/public_simplechat/readme.md +++ b/examples/server/public_simplechat/readme.md @@ -3,6 +3,13 @@ by Humans for All. +## quickstart + +To run from the build dir + +bin/llama-server -m path/model.gguf --path ../examples/server/public_simplechat + +Continue reading for the details. ## overview @@ -14,6 +21,8 @@ own system prompts. This allows seeing the generated text / ai-model response in oneshot at the end, after it is fully generated, or potentially as it is being generated, in a streamed manner from the server/ai-model. +![Chat and Settings screens](./simplechat_screens.webp "Chat and Settings screens") + Auto saves the chat session locally as and when the chat is progressing and inturn at a later time when you open SimpleChat, option is provided to restore the old chat session, if a matching one exists. @@ -170,17 +179,23 @@ It is attached to the document object. Some of these can also be updated using t The histogram/freq based trimming logic is currently tuned for english language wrt its is-it-a-alpabetic|numeral-char regex match logic. - chatRequestOptions - maintains the list of options/fields to send along with chat request, + apiRequestOptions - maintains the list of options/fields to send along with api request, irrespective of whether /chat/completions or /completions endpoint. If you want to add additional options/fields to send to the server/ai-model, and or modify the existing options value or remove them, for now you can update this global var using browser's development-tools/console. - For string and numeric fields in chatRequestOptions, including even those added by a user - at runtime by directly modifying gMe.chatRequestOptions, setting ui entries will be auto + For string, numeric and boolean fields in apiRequestOptions, including even those added by a + user at runtime by directly modifying gMe.apiRequestOptions, setting ui entries will be auto created. + cache_prompt option supported by example/server is allowed to be controlled by user, so that + any caching supported wrt system-prompt and chat history, if usable can get used. When chat + history sliding window is enabled, cache_prompt logic may or may not kick in at the backend + wrt same, based on aspects related to model, positional encoding, attention mechanism etal. + However system prompt should ideally get the benefit of caching. + headers - maintains the list of http headers sent when request is made to the server. By default Content-Type is set to application/json. Additionally Authorization entry is provided, which can be set if needed using the settings ui. @@ -197,10 +212,10 @@ It is attached to the document object. Some of these can also be updated using t >0 : Send the latest chat history from the latest system prompt, limited to specified cnt. -By using gMe's iRecentUserMsgCnt and chatRequestOptions.max_tokens one can try to control the -implications of loading of the ai-model's context window by chat history, wrt chat response to -some extent in a simple crude way. You may also want to control the context size enabled when -the server loads ai-model, on the server end. +By using gMe's iRecentUserMsgCnt and apiRequestOptions.max_tokens/n_predict one can try to control +the implications of loading of the ai-model's context window by chat history, wrt chat response to +some extent in a simple crude way. You may also want to control the context size enabled when the +server loads ai-model, on the server end. Sometimes the browser may be stuborn with caching of the file, so your updates to html/css/js @@ -237,12 +252,12 @@ also be started with a model context size of 1k or more, to be on safe side. internal n_predict, for now add the same here on the client side, maybe later add max_tokens to /completions endpoint handling code on server side. -NOTE: One may want to experiment with frequency/presence penalty fields in chatRequestOptions -wrt the set of fields sent to server along with the user query. To check how the model behaves +NOTE: One may want to experiment with frequency/presence penalty fields in apiRequestOptions +wrt the set of fields sent to server along with the user query, to check how the model behaves wrt repeatations in general in the generated text response. A end-user can change these behaviour by editing gMe from browser's devel-tool/console or by -using the providing settings ui. +using the provided settings ui (for settings exposed through the ui). ### OpenAi / Equivalent API WebService @@ -253,7 +268,7 @@ for a minimal chatting experimentation by setting the below. * the baseUrl in settings ui * https://api.openai.com/v1 or similar -* Wrt request body - gMe.chatRequestOptions +* Wrt request body - gMe.apiRequestOptions * model (settings ui) * any additional fields if required in future diff --git a/examples/server/public_simplechat/simplechat.js b/examples/server/public_simplechat/simplechat.js index 25afb25649139..8e0df3b61df2b 100644 --- a/examples/server/public_simplechat/simplechat.js +++ b/examples/server/public_simplechat/simplechat.js @@ -222,8 +222,8 @@ class SimpleChat { * @param {Object} obj */ request_jsonstr_extend(obj) { - for(let k in gMe.chatRequestOptions) { - obj[k] = gMe.chatRequestOptions[k]; + for(let k in gMe.apiRequestOptions) { + obj[k] = gMe.apiRequestOptions[k]; } if (gMe.bStream) { obj["stream"] = true; @@ -740,11 +740,12 @@ class Me { "Authorization": "", // Authorization: Bearer OPENAI_API_KEY } // Add needed fields wrt json object to be sent wrt LLM web services completions endpoint. - this.chatRequestOptions = { + this.apiRequestOptions = { "model": "gpt-3.5-turbo", "temperature": 0.7, "max_tokens": 1024, "n_predict": 1024, + "cache_prompt": false, //"frequency_penalty": 1.2, //"presence_penalty": 1.2, }; @@ -800,51 +801,55 @@ class Me { ui.el_create_append_p(`bStream:${this.bStream}`, elDiv); - ui.el_create_append_p(`bCompletionFreshChatAlways:${this.bCompletionFreshChatAlways}`, elDiv); - - ui.el_create_append_p(`bCompletionInsertStandardRolePrefix:${this.bCompletionInsertStandardRolePrefix}`, elDiv); - ui.el_create_append_p(`bTrimGarbage:${this.bTrimGarbage}`, elDiv); + ui.el_create_append_p(`ApiEndPoint:${this.apiEP}`, elDiv); + ui.el_create_append_p(`iRecentUserMsgCnt:${this.iRecentUserMsgCnt}`, elDiv); - ui.el_create_append_p(`ApiEndPoint:${this.apiEP}`, elDiv); + ui.el_create_append_p(`bCompletionFreshChatAlways:${this.bCompletionFreshChatAlways}`, elDiv); + + ui.el_create_append_p(`bCompletionInsertStandardRolePrefix:${this.bCompletionInsertStandardRolePrefix}`, elDiv); } - ui.el_create_append_p(`chatRequestOptions:${JSON.stringify(this.chatRequestOptions, null, " - ")}`, elDiv); + ui.el_create_append_p(`apiRequestOptions:${JSON.stringify(this.apiRequestOptions, null, " - ")}`, elDiv); ui.el_create_append_p(`headers:${JSON.stringify(this.headers, null, " - ")}`, elDiv); } /** - * Auto create ui input elements for fields in ChatRequestOptions + * Auto create ui input elements for fields in apiRequestOptions * Currently supports text and number field types. * @param {HTMLDivElement} elDiv */ - show_settings_chatrequestoptions(elDiv) { + show_settings_apirequestoptions(elDiv) { let typeDict = { "string": "text", "number": "number", }; let fs = document.createElement("fieldset"); let legend = document.createElement("legend"); - legend.innerText = "ChatRequestOptions"; + legend.innerText = "ApiRequestOptions"; fs.appendChild(legend); elDiv.appendChild(fs); - for(const k in this.chatRequestOptions) { - let val = this.chatRequestOptions[k]; + for(const k in this.apiRequestOptions) { + let val = this.apiRequestOptions[k]; let type = typeof(val); - if (!((type == "string") || (type == "number"))) { - continue; + if (((type == "string") || (type == "number"))) { + let inp = ui.el_creatediv_input(`Set${k}`, k, typeDict[type], this.apiRequestOptions[k], (val)=>{ + if (type == "number") { + val = Number(val); + } + this.apiRequestOptions[k] = val; + }); + fs.appendChild(inp.div); + } else if (type == "boolean") { + let bbtn = ui.el_creatediv_boolbutton(`Set{k}`, k, {true: "true", false: "false"}, val, (userVal)=>{ + this.apiRequestOptions[k] = userVal; + }); + fs.appendChild(bbtn.div); } - let inp = ui.el_creatediv_input(`Set${k}`, k, typeDict[type], this.chatRequestOptions[k], (val)=>{ - if (type == "number") { - val = Number(val); - } - this.chatRequestOptions[k] = val; - }); - fs.appendChild(inp.div); } } @@ -870,32 +875,32 @@ class Me { }); elDiv.appendChild(bb.div); - bb = ui.el_creatediv_boolbutton("SetCompletionFreshChatAlways", "CompletionFreshChatAlways", {true: "[+] yes fresh", false: "[-] no, with history"}, this.bCompletionFreshChatAlways, (val)=>{ - this.bCompletionFreshChatAlways = val; + bb = ui.el_creatediv_boolbutton("SetTrimGarbage", "TrimGarbage", {true: "[+] yes trim", false: "[-] dont trim"}, this.bTrimGarbage, (val)=>{ + this.bTrimGarbage = val; }); elDiv.appendChild(bb.div); - bb = ui.el_creatediv_boolbutton("SetCompletionInsertStandardRolePrefix", "CompletionInsertStandardRolePrefix", {true: "[+] yes insert", false: "[-] dont insert"}, this.bCompletionInsertStandardRolePrefix, (val)=>{ - this.bCompletionInsertStandardRolePrefix = val; - }); - elDiv.appendChild(bb.div); + this.show_settings_apirequestoptions(elDiv); - bb = ui.el_creatediv_boolbutton("SetTrimGarbage", "TrimGarbage", {true: "[+] yes trim", false: "[-] dont trim"}, this.bTrimGarbage, (val)=>{ - this.bTrimGarbage = val; + let sel = ui.el_creatediv_select("SetApiEP", "ApiEndPoint", ApiEP.Type, this.apiEP, (val)=>{ + this.apiEP = ApiEP.Type[val]; }); - elDiv.appendChild(bb.div); + elDiv.appendChild(sel.div); - let sel = ui.el_creatediv_select("SetChatHistoryInCtxt", "ChatHistoryInCtxt", this.sRecentUserMsgCnt, this.iRecentUserMsgCnt, (val)=>{ + sel = ui.el_creatediv_select("SetChatHistoryInCtxt", "ChatHistoryInCtxt", this.sRecentUserMsgCnt, this.iRecentUserMsgCnt, (val)=>{ this.iRecentUserMsgCnt = this.sRecentUserMsgCnt[val]; }); elDiv.appendChild(sel.div); - sel = ui.el_creatediv_select("SetApiEP", "ApiEndPoint", ApiEP.Type, this.apiEP, (val)=>{ - this.apiEP = ApiEP.Type[val]; + bb = ui.el_creatediv_boolbutton("SetCompletionFreshChatAlways", "CompletionFreshChatAlways", {true: "[+] yes fresh", false: "[-] no, with history"}, this.bCompletionFreshChatAlways, (val)=>{ + this.bCompletionFreshChatAlways = val; }); - elDiv.appendChild(sel.div); + elDiv.appendChild(bb.div); - this.show_settings_chatrequestoptions(elDiv); + bb = ui.el_creatediv_boolbutton("SetCompletionInsertStandardRolePrefix", "CompletionInsertStandardRolePrefix", {true: "[+] yes insert", false: "[-] dont insert"}, this.bCompletionInsertStandardRolePrefix, (val)=>{ + this.bCompletionInsertStandardRolePrefix = val; + }); + elDiv.appendChild(bb.div); } diff --git a/examples/server/public_simplechat/simplechat_screens.webp b/examples/server/public_simplechat/simplechat_screens.webp new file mode 100644 index 0000000000000..ccea443960516 Binary files /dev/null and b/examples/server/public_simplechat/simplechat_screens.webp differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f9a86961f9c8e..ae768097baa0e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2606,17 +2606,9 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used { - json chat; - chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}}); - chat.push_back({{"role", "user"}, {"content", "Hello"}}); - chat.push_back({{"role", "assistant"}, {"content", "Hi there"}}); - chat.push_back({{"role", "user"}, {"content", "How are you?"}}); - - const std::string chat_example = format_chat(ctx_server.model, params.chat_template, chat); - LOG_INFO("chat template", { - {"chat_example", chat_example}, - {"built_in", params.chat_template.empty()}, + {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, + {"built_in", params.chat_template.empty()}, }); } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 63fde9c9faabe..7ef2a519a10c7 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -118,36 +118,17 @@ static inline void server_log(const char * level, const char * function, int lin // Format given chat. If tmpl is empty, we take the template from model metadata inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { - size_t alloc_size = 0; - // vector holding all allocated string to be passed to llama_chat_apply_template - std::vector str(messages.size() * 2); - std::vector chat(messages.size()); + std::vector chat; for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; - str[i*2 + 0] = json_value(curr_msg, "role", std::string("")); - str[i*2 + 1] = json_value(curr_msg, "content", std::string("")); - alloc_size += str[i*2 + 1].length(); - chat[i].role = str[i*2 + 0].c_str(); - chat[i].content = str[i*2 + 1].c_str(); + std::string role = json_value(curr_msg, "role", std::string("")); + std::string content = json_value(curr_msg, "content", std::string("")); + chat.push_back({role, content}); } - const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); - std::vector buf(alloc_size * 2); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); - } - - const std::string formatted_chat(buf.data(), res); - + auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); - return formatted_chat; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f914efd712665..0acfda91d3e51 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; -#if defined(GGML_CUDA_FORCE_MMQ) - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); +#ifdef GGML_CUDA_FORCE_MMQ + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); -#endif -#if defined(CUDA_USE_TENSOR_CORES) - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); +#endif // GGML_CUDA_FORCE_MMQ +#ifdef GGML_CUDA_FORCE_CUBLAS + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__); -#endif + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); +#endif // GGML_CUDA_FORCE_CUBLAS GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); - int64_t min_compute_capability = INT_MAX; + bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + bool any_gpus_with_slow_fp16 = false; - bool any_pascal_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; @@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - if (min_compute_capability > ggml_cuda_info().devices[id].cc) { - min_compute_capability = ggml_cuda_info().devices[id].cc; - } - if (ggml_cuda_info().devices[id].cc == 610) { - any_pascal_with_slow_fp16 = true; - } + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } } else { - min_compute_capability = ggml_cuda_info().devices[ctx.device].cc; - any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610; + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } - // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; - - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - - bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - - const bool fp16_performance_good = min_compute_capability >= CC_RDNA1; - -#ifdef CUDA_USE_TENSOR_CORES - use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3; -#endif // CUDA_USE_TENSOR_CORES - -#else - - // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0) - const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16; - - // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1 - use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A; - use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A; - -#ifdef CUDA_USE_TENSOR_CORES - // when tensor cores are available, use them for large batch size - // ref: https://github.com/ggerganov/llama.cpp/pull/3776 - use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE); -#endif // CUDA_USE_TENSOR_CORES - -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - // if mmvq is available it's a better choice than dmmv: #ifndef GGML_CUDA_FORCE_DMMV use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; @@ -1947,14 +1918,15 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // KQ single-batch + if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // FP32 precision KQ single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // KQV single-batch + } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 5bd24ebe5fa79..8d00db6c193ff 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -146,23 +146,6 @@ #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) -// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication -// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant -// for large computational tasks. the drawback is that this requires some extra amount of VRAM: -// - 7B quantum model: +100-200 MB -// - 13B quantum model: +200-400 MB -// -//#define GGML_CUDA_FORCE_MMQ - -// TODO: improve this to be correct for more hardware -// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores -#if !defined(GGML_CUDA_FORCE_MMQ) -#define CUDA_USE_TENSOR_CORES -#endif - -#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available - #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #if defined(_MSC_VER) @@ -343,15 +326,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int #define INT8_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING -static bool fast_fp16_available(const int cc) { +static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } -static bool fp16_mma_available(const int cc) { +static constexpr bool fp16_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; } -static bool int8_mma_available(const int cc) { +static constexpr bool int8_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_TURING; } @@ -643,19 +626,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -static int get_mmq_x_max_host(const int cc) { -#ifdef CUDA_USE_TENSOR_CORES - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; -#else - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; -#endif // CUDA_USE_TENSOR_CORES -} - -// Round rows to this value for --split-mode row: -static int get_mmq_y_host(const int cc) { - return cc >= CC_VOLTA ? 128 : 64; -} - ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml-cuda/mma.cuh b/ggml-cuda/mma.cuh index 63e07fbc21291..0301a52f9ea71 100644 --- a/ggml-cuda/mma.cuh +++ b/ggml-cuda/mma.cuh @@ -20,6 +20,20 @@ struct mma_int_A_I16K4 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) + const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(x[0]), "+r"(x[1]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_i(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_A_I16K8 { @@ -42,6 +56,20 @@ struct mma_int_A_I16K8 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) + const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_i(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_B_J8K4 { @@ -64,6 +92,20 @@ struct mma_int_B_J8K4 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster + const int * xs = xs0 + (threadIdx.x%J)*stride; + asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" + : "+r"(x[0]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_B_J8K8 { @@ -86,6 +128,20 @@ struct mma_int_B_J8K8 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster + const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(x[0]), "+r"(x[1]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_C_I16J8 { diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 6dbd85feff2fa..0308beaccbaa3 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -69,7 +69,13 @@ void ggml_cuda_op_mul_mat_q( GGML_UNUSED(src1_ddf_i); } -bool ggml_cuda_supports_mmq(enum ggml_type type) { +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { +#ifdef GGML_CUDA_FORCE_CUBLAS + return false; +#endif // GGML_CUDA_FORCE_CUBLAS + + bool mmq_supported; + switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -81,8 +87,32 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: - return true; + mmq_supported = true; + break; default: - return false; + mmq_supported = false; + break; + } + + if (!mmq_supported) { + return false; + } + + if (int8_mma_available(cc)) { + return true; + } + + if (cc < MIN_CC_DP4A) { + return false; } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (cc < CC_OFFSET_AMD) { + return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index e2d07c20253ae..31fcbf1397b6b 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -7,15 +7,10 @@ #include #include -#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) -#define MMQ_NWARPS 8 +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -typedef void (*load_tiles_mmq_t)( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0); +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); struct block_q8_1_mmq { @@ -31,25 +26,42 @@ struct tile_x_sizes { int sc; }; -// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row +static constexpr int get_mmq_x_max_host(const int cc) { + return int8_mma_available(cc) ? 128 : +#ifdef GGML_CUDA_FORCE_MMQ + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; +#else + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ +} static constexpr __device__ int get_mmq_x_max_device() { +#ifdef INT8_MMA_AVAILABLE + return 128; +#else // INT8_MMA_AVAILABLE + #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return 64; -#else + return 128; +#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + #if __CUDA_ARCH__ >= CC_VOLTA -#ifdef CUDA_USE_TENSOR_CORES - return MMQ_MAX_BATCH_SIZE; -#else +#ifdef GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#else // GGML_CUDA_FORCE_MMQ return 128; -#endif // CUDA_USE_TENSOR_CORES -#else +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= CC_VOLTA + return 64; #endif // __CUDA_ARCH__ >= CC_VOLTA + #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // INT8_MMA_AVAILABLE } -// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row +static constexpr int get_mmq_y_host(const int cc) { + return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64; +} static constexpr __device__ int get_mmq_y_device() { #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) @@ -63,51 +75,101 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} -#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} -#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} -#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} - -#define GET_TILE_X_SIZES_BODY \ - return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \ - type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \ - type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \ - type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \ - type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \ - type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \ - type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \ - type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \ - type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \ - type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \ - tile_x_sizes{0, 0, 0} - -static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) { - GET_TILE_X_SIZES_BODY; +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} +#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} + +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : + type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : + type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + tile_x_sizes{0, 0, 0}; } -template -static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) { - GET_TILE_X_SIZES_BODY; +#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) +#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) +#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4) +#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4) +#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0) +#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2) +#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : + type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : + type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + 0; } +#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) +#define MMQ_NWARPS 8 + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; +} + +#ifdef INT8_MMA_AVAILABLE +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { + return 8; +} +#endif // INT8_MMA_AVAILABLE + // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; - float * x_dmf = (float *) x_dm; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -118,7 +180,11 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -134,17 +200,21 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_df = (const float *) x_dm; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -175,76 +245,90 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A; - float dA[mma_C::ne/2]; + mma_A A[ntx]; + float dA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int k = k0 + mma_A::get_k(l) % QI4_0; + const int shift = 4*(mma_A::get_k(l) / QI4_0); + + A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); + } - A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808); - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; - mma_B B; - half2 dsB[mma_C::ne/2]; - #pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + mma_B B; + float dB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -259,7 +343,11 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -275,16 +363,21 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -315,51 +408,53 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; + typedef mma_int_A_I16K4 mma_A_K4; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A; - half2 dmA[mma_C::ne/2]; + mma_A A[ntx]; + half2 dmA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); + for (int n = 0; n < ntx; ++n) { + ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1); + A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F; + A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F; + A[n].x[0] &= 0x0F0F0F0F; + A[n].x[1] &= 0x0F0F0F0F; - A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -367,24 +462,35 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; - sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -409,8 +515,6 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 @@ -418,12 +522,17 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; const int kbxd = threadIdx.x % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { @@ -435,19 +544,23 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -457,70 +570,57 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE]; - } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], + x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - mma_A A; - float dA[mma_C::ne/2]; + mma_A A[ntx]; + float dA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0; + for (int n = 0; n < ntx; ++n) { + A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0); - A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k]; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -528,23 +628,34 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -568,15 +679,19 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -592,18 +707,23 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -613,69 +733,57 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE]; - } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A; - half2 dmA[mma_C::ne/2]; + mma_A A[ntx]; + half2 dmA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1; + for (int n = 0; n < ntx; ++n) { + A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1); - A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k]; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -683,28 +791,38 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; - sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -716,7 +834,11 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; @@ -732,19 +854,23 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -755,7 +881,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int i = i0 + threadIdx.x; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); } } @@ -763,51 +889,48 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - mma_A A; - float dA[mma_C::ne/2]; + mma_A A[ntx]; + float dA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l); + for (int n = 0; n < ntx; ++n) { + A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); - A.x[l] = x_qs[i*(WARP_SIZE + 1) + k]; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = k0 + mma_B::get_k(l); + B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -815,22 +938,34 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI2_K; const int kqsx = threadIdx.x % QI2_K; @@ -859,7 +994,11 @@ template static __device__ __forceinlin continue; } - x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; +#endif // INT8_MMA_AVAILABLE } const int sc_m = bxi->scales[kqsx]; @@ -870,15 +1009,21 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE - x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik; +#else + x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -899,61 +1044,63 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[2]; - float dA[mma_C::ne/2][2]; - float mA[mma_C::ne/2][2]; + mma_A A[ntx][2]; + float dA[ntx][mma_C::ne/2][2]; + float mA[ntx][mma_C::ne/2][2]; #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int shift = 2*mma_A::get_k(l); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int shift = 2*mma_A::get_k(l); - A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303; - A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303; - } + A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303; + A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303; + } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); #pragma unroll - for (int kk = 0; kk < 2; ++kk) { - const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]); + for (int kdm = 0; kdm < 2; ++kdm) { + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]); - dA[l][kk] = dm.x; - mA[l][kk] = dm.y; + dA[n][l][kdm] = dm.x; + mA[n][l][kdm] = dm.y; + } } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C Cd[2]; - mma_C Cm[2]; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B[2]; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE; + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); - B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; - B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -961,9 +1108,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; } - Cd[0].mma_K4(A[0], B[0]); - Cd[1].mma_K4(A[1], B[1]); - + mma_C Cm[2]; mma_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; @@ -971,19 +1116,38 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( Cm[1].mma_K4(A1, B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2]; + for (int n = 0; n < ntx; ++n) { + mma_C Cd[2]; + + Cd[0].mma_K4(A[n][0], B[0]); + Cd[1].mma_K4(A[n][1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += ( + Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI3_K; const int kqsx = threadIdx.x % QI3_K; @@ -1015,13 +1179,16 @@ template static __device__ __forceinlin continue; } - x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k; +#else + x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; +#endif // INT8_MMA_AVAILABLE } } const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; const int kbxd = threadIdx.x % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { @@ -1033,7 +1200,11 @@ template static __device__ __forceinlin const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1058,16 +1229,22 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc; +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc; +#else + x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_df = (const float *) x_dm; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -1093,69 +1270,72 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[2]; - int scA[mma_C::ne/2][2]; - float dA[mma_C::ne/2]; + mma_A A[ntx][2]; + int scA[ntx][mma_C::ne/2][2]; + float dA[ntx][mma_C::ne/2]; #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = QR3_K*k0 + mma_A::get_k(l); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int k = QR3_K*k0 + mma_A::get_k(l); - A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; - A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; - A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404); - A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404); - } + A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; + A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; + A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404); + A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404); + } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int kbx = k0 / QI3_K; - const int ky = (k0 % QI3_K) * QR3_K; - const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + const int kbx = k0 / QI3_K; + const int ky = (k0 % QI3_K) * QR3_K; + const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4; - scA[l][0] = sc[0]; - scA[l][1] = sc[1]; - } + scA[n][l][0] = sc[0]; + scA[n][l][1] = sc[1]; + } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K]; + } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C[2]; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B[2]; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE; + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); - B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; - B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1163,23 +1343,37 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C[0].mma_K4(A[0], B[0]); - C[1].mma_K4(A[1], B[1]); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][0], B[0]); + C[1].mma_K4(A[n][1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); + int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI4_K const int kqsx = threadIdx.x; // threadIdx.x % QI4_K @@ -1194,7 +1388,11 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 @@ -1210,7 +1408,11 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1231,15 +1433,22 @@ template static __device__ __forceinlin int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8; +#else + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1262,71 +1471,79 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; + const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][2]; + int scA[ntx][mma_C::ne/2][2]; + int mA[ntx][mma_C::ne/2][2]; + half2 dmA[ntx][mma_C::ne/2]; - mma_A A[2]; - int scA[mma_C::ne/2][2]; - int mA[mma_C::ne/2][2]; - half2 dmA[mma_C::ne/2]; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { + for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l); + for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) { + A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K); - A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F; +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F; + A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F; + } } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; + const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8); + const uint8_t * m = sc + 8; - scA[l][kvdr/4] = sc[kvdr/4]; - mA[l][kvdr/4] = m[kvdr/4]; + scA[n][l][kvdr/4] = sc[kvdr/4]; + mA[n][l][kvdr/4] = m[kvdr/4]; + } } - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K]; + } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - float tmpd[mma_C::ne] = {0.0f}; - float tmpm[mma_C::ne] = {0.0f}; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmpd[ntx][mma_C::ne] = {{0.0f}}; + float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { - mma_C C; + for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1334,29 +1551,46 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A[kvdr/4], B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][kvdr/4], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]); + for (int l = 0; l < mma_C::ne; ++l) { + tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + } } } #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI5_K const int kqsx = threadIdx.x; // threadIdx.x % QI5_K @@ -1383,8 +1617,13 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); - x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 @@ -1400,7 +1639,11 @@ template static __device__ __forceinlin const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1421,17 +1664,24 @@ template static __device__ __forceinlin int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8; +#else + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -1452,71 +1702,70 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][2]; + int scA[ntx][mma_C::ne/2][2]; + int mA[ntx][mma_C::ne/2][2]; + half2 dmA[ntx][mma_C::ne/2]; - mma_A A[2]; - int scA[mma_C::ne/2][2]; - int mA[mma_C::ne/2][2]; - half2 dmA[mma_C::ne/2]; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { + for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l); - - A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k]; - } + for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { + A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; + const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8); + const uint8_t * m = sc + 8; - scA[l][kvdr/4] = sc[kvdr/4]; - mA[l][kvdr/4] = m[kvdr/4]; + scA[n][l][kvdr/4] = sc[kvdr/4]; + mA[n][l][kvdr/4] = m[kvdr/4]; + } } - } -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + #pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K]; + } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - float tmpd[mma_C::ne] = {0.0f}; - float tmpm[mma_C::ne] = {0.0f}; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmpd[ntx][mma_C::ne] = {{0.0f}}; + float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { - mma_C C; mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1524,29 +1773,46 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A[kvdr/4], B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][kvdr/4], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]); + for (int l = 0; l < mma_C::ne; ++l) { + tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + } } } #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI6_K const int kqsx = threadIdx.x; // threadIdx.x % QI6_K @@ -1573,13 +1839,17 @@ template static __device__ __forceinlin const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); - x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { @@ -1591,7 +1861,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1604,18 +1878,24 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#else + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -1629,80 +1909,77 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, - x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); } } } template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][4]; + int scA[ntx][mma_C::ne/2][4]; + float dA[ntx][mma_C::ne/2]; - mma_A A[4]; - int scA[mma_C::ne/2][4]; - float dA[mma_C::ne/2]; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { + for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l); - - A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0]; - A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K]; - } + for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { + A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); + const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]); - scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0]; - scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1]; + scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0]; + scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1]; + } } - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K]; + } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - float tmp[mma_C::ne] = {0.0f}; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmp[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { - mma_C C[2]; mma_B B[2]; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE; + const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE; + B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K); - B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; - B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1710,22 +1987,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C[0].mma_K4(A[kvdr/2 + 0], B[0]); - C[1].mma_K4(A[kvdr/2 + 1], B[1]); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][kvdr/2 + 0], B[0]); + C[1].mma_K4(A[n][kvdr/2 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2]; + } } } #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2]; + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } @@ -1761,28 +2045,35 @@ static __device__ __forceinline__ void mmq_write_back_mma( typedef mma_int_C_I16J8 mma_C; - const int i0 = threadIdx.y*mma_C::I; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); #ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); #endif // INT8_MMA_AVAILABLE #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); - if (j > j_max) { - continue; - } + if (j > j_max) { + continue; + } - const int i = i0 + mma_C::get_i(l); + const int i = i0 + n*mma_C::I + mma_C::get_i(l); - if (need_check && i > i_max) { - continue; - } + if (need_check && i > i_max) { + continue; + } - dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l]; + dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l]; + } } } } @@ -1910,6 +2201,10 @@ static __device__ void mul_mat_q_process_tile( constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + extern __shared__ char data_mul_mat_q[]; + int * tile_y = (int *) data_mul_mat_q; + int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); + #ifdef INT8_MMA_AVAILABLE constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; @@ -1918,14 +2213,6 @@ static __device__ void mul_mat_q_process_tile( constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; #endif // INT8_MMA_AVAILABLE - constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); - - extern __shared__ char data_mul_mat_q[]; - int * tile_x_qs = (int *) data_mul_mat_q; - half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs); - int * tile_x_sc = (int *) (tile_x_dm + txs.dm); - int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)] - constexpr int blocks_per_warp = WARP_SIZE / qi; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; @@ -1937,7 +2224,7 @@ static __device__ void mul_mat_q_process_tile( for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) { - load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); #pragma unroll for (int kr = 0; kr < qr; ++kr) { @@ -1953,7 +2240,7 @@ static __device__ void mul_mat_q_process_tile( // #pragma unroll // unrolling this loop causes too much register pressure for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0); + vec_dot(tile_x, tile_y, sum, k0); } __syncthreads(); @@ -1987,7 +2274,7 @@ static __global__ void mul_mat_q( const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { // Skip unused template specializations for faster compilation: - if (mmq_x > get_mmq_x_max_device()) { + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { NO_DEVICE_CODE; return; } @@ -2139,11 +2426,12 @@ struct mmq_args { int64_t ne0; }; -static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) { - const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); - - const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); +template +static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } @@ -2156,7 +2444,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); - const int shmem = mmq_get_shmem(type, mmq_x, mmq_y); + const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -2225,12 +2513,17 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda int nparts_best = INT_MAX; for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); + + if (mmq_x % granularity != 0 || mmq_get_shmem(mmq_x, mmq_y, cc) > smpbo) { + continue; + } + const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x; const int nwaves_xy_tiling = ntiles_x*block_num_y; - const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling; - if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) { + if (nparts < nparts_best) { mmq_x_best = mmq_x; nparts_best = nparts; } @@ -2314,4 +2607,4 @@ void ggml_cuda_op_mul_mat_q( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_supports_mmq(enum ggml_type type); +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); diff --git a/ggml-cuda/mmvq.cuh b/ggml-cuda/mmvq.cuh index 88c42c4b7a8fb..d9e42fdd6d16c 100644 --- a/ggml-cuda/mmvq.cuh +++ b/ggml-cuda/mmvq.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. + void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e5ddf4a346c36..db045336f1edb 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4620,7 +4620,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 100594b303526..222a2d137b08f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -75,6 +75,11 @@ class Rope: SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + class Split: + LLM_KV_SPLIT_NO = "split.no" + LLM_KV_SPLIT_COUNT = "split.count" + LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count" + class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" INNER_SIZE = "{arch}.ssm.inner_size" diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index e48bc00c388c8..20432bd258458 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -69,6 +69,7 @@ class GGUFReader: # I - same as host, S - swapped byte_order: Literal['I'] | Literal['S'] = 'I' alignment: int = GGUF_DEFAULT_ALIGNMENT + data_offset: int # Note: Internal helper, API may change. gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { @@ -88,9 +89,13 @@ class GGUFReader: def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'): self.data = np.memmap(path, mode = mode) offs = 0 + + # Check for GGUF magic if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: raise ValueError('GGUF magic invalid') offs += 4 + + # Check GGUF version temp_version = self._get(offs, np.uint32) if temp_version[0] & 65535 == 0: # If we get 0 here that means it's (probably) a GGUF file created for @@ -103,12 +108,16 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r self.fields: OrderedDict[str, ReaderField] = OrderedDict() self.tensors: list[ReaderTensor] = [] offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) + + # Check tensor count and kv count temp_counts = self._get(offs, np.uint64, 2) offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) tensor_count, kv_count = temp_counts offs = self._build_fields(offs, kv_count) - offs, tensors_fields = self._build_tensors_fields(offs, tensor_count) + + # Build Tensor Info Fields + offs, tensors_fields = self._build_tensor_info(offs, tensor_count) new_align = self.fields.get('general.alignment') if new_align is not None: if new_align.types != [GGUFValueType.UINT32]: @@ -117,6 +126,7 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r padding = offs % self.alignment if padding != 0: offs += self.alignment - padding + self.data_offset = offs self._build_tensors(offs, tensors_fields) _DT = TypeVar('_DT', bound = npt.DTypeLike) @@ -193,18 +203,29 @@ def _get_field_parts( # We can't deal with this one. raise ValueError('Unknown/unhandled field type {gtype}') - def _get_tensor(self, orig_offs: int) -> ReaderField: + def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: offs = orig_offs + + # Get Tensor Name name_len, name_data = self._get_str(offs) offs += int(name_len.nbytes + name_data.nbytes) + + # Get Tensor Dimensions Count n_dims = self._get(offs, np.uint32) offs += int(n_dims.nbytes) + + # Get Tensor Dimension Array dims = self._get(offs, np.uint64, n_dims[0]) offs += int(dims.nbytes) + + # Get Tensor Encoding Scheme Type raw_dtype = self._get(offs, np.uint32) offs += int(raw_dtype.nbytes) + + # Get Tensor Offset offset_tensor = self._get(offs, np.uint64) offs += int(offset_tensor.nbytes) + return ReaderField( orig_offs, str(bytes(name_data), encoding = 'utf-8'), @@ -233,10 +254,10 @@ def _build_fields(self, offs: int, count: int) -> int: offs += field_size return offs - def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: + def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: tensor_fields = [] for _ in range(count): - field = self._get_tensor(offs) + field = self._get_tensor_info_field(offs) offs += sum(int(part.nbytes) for part in field.parts) tensor_fields.append(field) return offs, tensor_fields diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3b841a62583a3..9869f6fe3445a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -7,6 +7,7 @@ import tempfile from dataclasses import dataclass from enum import Enum, auto +from pathlib import Path from io import BufferedWriter from typing import IO, Any, Sequence, Mapping from string import ascii_letters, digits @@ -31,6 +32,9 @@ logger = logging.getLogger(__name__) +SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf" + + @dataclass class TensorInfo: shape: Sequence[int] @@ -55,11 +59,11 @@ class WriterState(Enum): class GGUFWriter: - fout: BufferedWriter | None - path: os.PathLike[str] | str | None + fout: list[BufferedWriter] | None + path: Path | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: dict[str, TensorInfo] - kv_data: dict[str, GGUFValue] + tensors: list[dict[str, TensorInfo]] + kv_data: list[dict[str, GGUFValue]] state: WriterState _simple_value_packing = { GGUFValueType.UINT8: "B", @@ -76,26 +80,38 @@ class GGUFWriter: } def __init__( - self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, - endianess: GGUFEndian = GGUFEndian.LITTLE, + self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False ): self.fout = None - self.path = path + self.path = Path(path) if path else None self.arch = arch self.endianess = endianess self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.use_temp_file = use_temp_file self.temp_file = None - self.tensors = dict() - self.kv_data = dict() + self.tensors = [{}] + self.kv_data = [{}] + self.split_max_tensors = split_max_tensors + self.split_max_size = split_max_size + self.dry_run = dry_run + self.small_first_shard = small_first_shard logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) self.state = WriterState.NO_FILE + if self.small_first_shard: + self.tensors.append({}) + self.add_architecture() - def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: + def format_shard_names(self, path: Path) -> list[Path]: + if len(self.tensors) == 1: + return [path] + return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))] + + def open_output_file(self, path: Path | None = None) -> None: if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): # allow calling this multiple times as long as the path is the same return @@ -106,22 +122,58 @@ def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: self.path = path if self.path is not None: - if self.fout is not None: - self.fout.close() - self.fout = open(self.path, "wb") + filenames = self.print_plan() + self.fout = [open(filename, "wb") for filename in filenames] self.state = WriterState.EMPTY - def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: + def print_plan(self) -> list[Path]: + logger.info("Writing the following files:") + assert self.path is not None + filenames = self.format_shard_names(self.path) + assert len(filenames) == len(self.tensors) + for name, tensors in zip(filenames, self.tensors): + logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}") + + if self.dry_run: + logger.info("Dry run, not writing files") + exit() + + return filenames + + def add_shard_kv_data(self) -> None: + if len(self.tensors) == 1: + return + + total_tensors = sum(len(t) for t in self.tensors) + assert self.fout is not None + total_splits = len(self.fout) + self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits)) + for i, kv_data in enumerate(self.kv_data): + kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16) + kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16) + kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32) + + def write_header_to_file(self, path: Path | None = None) -> None: + if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0): + logger.warning("Model fails split requirements, not splitting") + self.open_output_file(path) if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') - self._write_packed(" None: @@ -129,13 +181,15 @@ def write_kv_data_to_file(self) -> None: raise ValueError(f'Expected output file to contain the header, got {self.state}') assert self.fout is not None - kv_data = bytearray() + for fout, kv_data in zip(self.fout, self.kv_data): + kv_bytes = bytearray() - for key, val in self.kv_data.items(): - kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) - kv_data += self._pack_val(val.value, val.type, add_vtype=True) + for key, val in kv_data.items(): + kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_bytes += self._pack_val(val.value, val.type, add_vtype=True) + + fout.write(kv_bytes) - self.fout.write(kv_data) self.flush() self.state = WriterState.KV_DATA @@ -144,28 +198,29 @@ def write_ti_data_to_file(self) -> None: raise ValueError(f'Expected output file to contain KV data, got {self.state}') assert self.fout is not None - ti_data = bytearray() - offset_tensor = 0 - - for name, ti in self.tensors.items(): - ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) - n_dims = len(ti.shape) - ti_data += self._pack("I", n_dims) - for i in range(n_dims): - ti_data += self._pack("Q", ti.shape[n_dims - 1 - i]) - ti_data += self._pack("I", ti.dtype) - ti_data += self._pack("Q", offset_tensor) - offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) - - self.fout.write(ti_data) - self.flush() + for fout, tensors in zip(self.fout, self.tensors): + ti_data = bytearray() + offset_tensor = 0 + + for name, ti in tensors.items(): + ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) + n_dims = len(ti.shape) + ti_data += self._pack("I", n_dims) + for j in range(n_dims): + ti_data += self._pack("Q", ti.shape[n_dims - 1 - j]) + ti_data += self._pack("I", ti.dtype) + ti_data += self._pack("Q", offset_tensor) + offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) + + fout.write(ti_data) + fout.flush() self.state = WriterState.TI_DATA def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: - if key in self.kv_data: + if any(key in kv_data for kv_data in self.kv_data): raise ValueError(f'Duplicated key name {key!r}') - self.kv_data[key] = GGUFValue(value=val, type=vtype) + self.kv_data[0][key] = GGUFValue(value=val, type=vtype) def add_uint8(self, key: str, val: int) -> None: self.add_key_value(key,val, GGUFValueType.UINT8) @@ -206,9 +261,6 @@ def add_string(self, key: str, val: str) -> None: self.add_key_value(key, val, GGUFValueType.STRING) def add_array(self, key: str, val: Sequence[Any]) -> None: - if not isinstance(val, Sequence): - raise ValueError("Value must be a sequence for array type") - self.add_key_value(key, val, GGUFValueType.ARRAY) @staticmethod @@ -222,7 +274,7 @@ def add_tensor_info( if self.state is not WriterState.NO_FILE: raise ValueError(f'Expected output file to be not yet opened, got {self.state}') - if name in self.tensors: + if any(name in tensors for tensors in self.tensors): raise ValueError(f'Duplicated tensor name {name!r}') if raw_dtype is None: @@ -247,7 +299,18 @@ def add_tensor_info( if tensor_dtype == np.uint8: tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) - self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) + # make sure there is at least one tensor before splitting + if len(self.tensors[-1]) > 0: + if ( # split when over tensor limit + self.split_max_tensors != 0 + and len(self.tensors[-1]) >= self.split_max_tensors + ) or ( # split when over size limit + self.split_max_size != 0 + and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size + ): + self.tensors.append({}) + + self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, @@ -264,7 +327,7 @@ def add_tensor( self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) if self.temp_file is None: - self.tensors[name].tensor = tensor + self.tensors[-1][name].tensor = tensor return tensor.tofile(self.temp_file) @@ -282,9 +345,24 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) - self.write_padding(self.fout, self.fout.tell()) - tensor.tofile(self.fout) - self.write_padding(self.fout, tensor.nbytes) + + file_id = -1 + for i, tensors in enumerate(self.tensors): + if len(tensors) > 0: + file_id = i + break + + fout = self.fout[file_id] + + # pop the first tensor info + # TODO: cleaner way to get the first key + first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0] + ti = self.tensors[file_id].pop(first_tensor_name) + assert ti.nbytes == tensor.nbytes + + self.write_padding(fout, fout.tell()) + tensor.tofile(fout) + self.write_padding(fout, tensor.nbytes) self.state = WriterState.WEIGHTS @@ -293,31 +371,43 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: assert self.fout is not None - self.write_padding(self.fout, self.fout.tell()) + for fout in self.fout: + self.write_padding(fout, fout.tell()) if self.temp_file is None: + shard_bar = None bar = None if progress: from tqdm import tqdm - total_bytes = sum(t.nbytes for t in self.tensors.values()) + total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values()) + if len(self.fout) > 1: + shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - # relying on the fact that Python dicts preserve insertion order (since 3.7) - for ti in self.tensors.values(): - assert ti.tensor is not None # can only iterate once over the tensors - assert ti.tensor.nbytes == ti.nbytes - ti.tensor.tofile(self.fout) - if bar is not None: - bar.update(ti.nbytes) - self.write_padding(self.fout, ti.nbytes) - ti.tensor = None + for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)): + if shard_bar is not None: + shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})") + total = sum(ti.nbytes for ti in tensors.values()) + shard_bar.reset(total=(total if total > 0 else None)) + + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + ti.tensor.tofile(fout) + if shard_bar is not None: + shard_bar.update(ti.nbytes) + if bar is not None: + bar.update(ti.nbytes) + self.write_padding(fout, ti.nbytes) + ti.tensor = None else: self.temp_file.seek(0) - shutil.copyfileobj(self.temp_file, self.fout) + shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1]) self.flush() self.temp_file.close() @@ -325,11 +415,13 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: def flush(self) -> None: assert self.fout is not None - self.fout.flush() + for fout in self.fout: + fout.flush() def close(self) -> None: if self.fout is not None: - self.fout.close() + for fout in self.fout: + fout.close() self.fout = None def add_architecture(self) -> None: @@ -626,6 +718,13 @@ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes: return kv_data - def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: - assert self.fout is not None - self.fout.write(self._pack(fmt, value, skip_pack_prefix)) + @staticmethod + def format_n_bytes_to_str(num: int) -> str: + if num == 0: + return "negligible - metadata only" + fnum = float(num) + for unit in ("", "K", "M", "G"): + if abs(fnum) < 1000.0: + return f"{fnum:3.1f}{unit}" + fnum /= 1000.0 + return f"{fnum:.1f}T - over 1TB, split recommended" diff --git a/gguf-py/scripts/gguf-dump.py b/gguf-py/scripts/gguf-dump.py index 92d14d6cd0a69..a73ca2776d32b 100755 --- a/gguf-py/scripts/gguf-dump.py +++ b/gguf-py/scripts/gguf-dump.py @@ -208,7 +208,9 @@ def translate_tensor_name(name): 'ssm_d': 'State space model skip connection', 'ssm_dt': 'State space model time step', 'ssm_out': 'State space model output projection', - 'blk': 'Block' + 'blk': 'Block', + 'enc': 'Encoder', + 'dec': 'Decoder', } expanded_words = [] @@ -291,6 +293,10 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None tensor_group_name = "base" if tensor_components[0] == 'blk': tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}" + elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk': + tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}" + elif tensor_components[0] in ['enc', 'dec']: + tensor_group_name = f"{tensor_components[0]}" # Check if new Tensor Group if tensor_group_name not in tensor_groups: @@ -313,6 +319,27 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None markdown_content += "\n" + markdown_content += "### Tensor Data Offset\n" + markdown_content += '\n' + markdown_content += 'This table contains the offset and data segment relative to start of file\n' + markdown_content += '\n' + + tensor_mapping_table: list[dict[str, str | int]] = [] + for key, tensor in enumerate(reader.tensors): + data_offset_pretty = '{0:#16x}'.format(tensor.data_offset) + data_size_pretty = '{0:#16x}'.format(tensor.n_bytes) + tensor_mapping_table.append({"t_id":key, "layer_name":tensor.name, "data_offset":data_offset_pretty, "data_size":data_size_pretty}) + + tensors_mapping_table_header_map = [ + {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, + {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, + {'key_name':'data_offset', 'header_name':'Data Offset (B)', 'align':'right'}, + {'key_name':'data_size', 'header_name':'Data Size (B)', 'align':'right'}, + ] + + markdown_content += markdown_table_with_alignment_support(tensors_mapping_table_header_map, tensor_mapping_table) + markdown_content += "\n" + for group in tensor_prefix_order: tensors = tensor_groups[group] group_elements = sum(tensor.n_elements for tensor in tensors) @@ -364,6 +391,8 @@ def main() -> None: parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") parser.add_argument("--json", action="store_true", help="Produce JSON output") parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") + parser.add_argument("--data-offset", action="store_true", help="Start of data offset") + parser.add_argument("--data-alignment", action="store_true", help="Data alignment applied globally to data field") parser.add_argument("--markdown", action="store_true", help="Produce markdown output") parser.add_argument("--verbose", action="store_true", help="increase output verbosity") @@ -371,7 +400,7 @@ def main() -> None: logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - if not args.json and not args.markdown: + if not args.json and not args.markdown and not args.data_offset and not args.data_alignment: logger.info(f'* Loading: {args.model}') reader = GGUFReader(args.model, 'r') @@ -380,6 +409,10 @@ def main() -> None: dump_metadata_json(reader, args) elif args.markdown: dump_markdown_metadata(reader, args) + elif args.data_offset: + print(reader.data_offset) # noqa: NP100 + elif args.data_alignment: + print(reader.alignment) # noqa: NP100 else: dump_metadata(reader, args) diff --git a/llama.cpp b/llama.cpp index 32683f2a6b34f..2960f198385b7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19363,10 +19363,10 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) { + } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) { // llama2 template and its variants // [variant] support system message - bool support_system_message = tmpl.find("<>") != std::string::npos; + bool support_system_message = tmpl.find("<>") != std::string::npos || tmpl == "mistral"; // [variant] space before + after response bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos; // [variant] add BOS inside history diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cef9a650bdfdf..d19ba8633e8c2 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -7,6 +7,7 @@ #include #include "llama.h" +#include "common.h" int main(void) { llama_chat_message conversation[] = { @@ -119,5 +120,24 @@ int main(void) { std::cout << output << "\n-------------------------\n"; assert(output == expected); } + + // test llama_chat_format_single + std::cout << "\n\n=== llama_chat_format_single ===\n\n"; + std::vector chat2; + chat2.push_back({"system", "You are a helpful assistant"}); + chat2.push_back({"user", "Hello"}); + chat2.push_back({"assistant", "I am assistant"}); + llama_chat_msg new_msg{"user", "How are you"}; + + auto fmt_single = [&](std::string tmpl) { + auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true); + std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n"; + return output; + }; + assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"); + assert(fmt_single("llama2") == "[INST] How are you [/INST]"); + assert(fmt_single("gemma") == "user\nHow are you\nmodel\n"); + assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + return 0; }