From 132b50458fc75135cd6135fa8ac65ee107669ae2 Mon Sep 17 00:00:00 2001 From: A Ghorbani Date: Sun, 21 Jul 2024 20:09:45 +0200 Subject: [PATCH 1/3] fix: add build for armv82 and v84 and sync llama.cpp --- android/src/main/CMakeLists.txt | 41 +- .../main/java/com/rnllama/LlamaContext.java | 53 +- cpp/common.cpp | 1979 ++--- cpp/common.h | 265 +- cpp/ggml-aarch64.c | 2193 ++++++ cpp/ggml-aarch64.h | 39 + cpp/ggml-alloc.c | 109 +- cpp/ggml-backend-impl.h | 28 +- cpp/ggml-backend.c | 276 +- cpp/ggml-backend.h | 39 +- cpp/ggml-common.h | 42 +- cpp/ggml-impl.h | 6 +- cpp/ggml-metal.h | 3 +- cpp/ggml-metal.m | 121 +- cpp/ggml-metal.metal | 6520 ++++++++++++++++ cpp/ggml-quants.c | 2366 +++--- cpp/ggml-quants.h | 39 +- cpp/ggml.c | 2457 ++---- cpp/ggml.h | 127 +- cpp/grammar-parser.cpp | 149 +- cpp/json-schema-to-grammar.cpp | 411 +- cpp/llama.cpp | 6936 ++++++++++++----- cpp/llama.h | 171 +- cpp/log.h | 2 +- cpp/rn-llama.hpp | 10 +- cpp/sampling.cpp | 23 +- cpp/sgemm.cpp | 45 +- cpp/sgemm.h | 2 +- cpp/unicode-data.cpp | 1653 ++-- cpp/unicode.cpp | 76 +- cpp/unicode.h | 3 +- llama.cpp | 2 +- scripts/bootstrap.sh | 52 +- scripts/common.cpp.patch | 2 +- scripts/ggml-metal.m.patch | 2 +- scripts/llama.cpp.patch | 4 +- 36 files changed, 19432 insertions(+), 6814 deletions(-) create mode 100644 cpp/ggml-aarch64.c create mode 100644 cpp/ggml-aarch64.h create mode 100644 cpp/ggml-metal.metal diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index 483f320e..f0ae959c 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -22,13 +22,14 @@ set( ${RNLLAMA_LIB_DIR}/unicode.cpp ${RNLLAMA_LIB_DIR}/llama.cpp ${RNLLAMA_LIB_DIR}/sgemm.cpp + ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/rn-llama.hpp ${CMAKE_SOURCE_DIR}/jni.cpp ) find_library(LOG_LIB log) -function(build_library target_name) +function(build_library target_name cpu_flags) add_library( ${target_name} SHARED @@ -37,32 +38,34 @@ function(build_library target_name) target_link_libraries(${target_name} ${LOG_LIB} android) - target_compile_options(${target_name} PRIVATE -pthread) - - if (${target_name} STREQUAL "rnllama_v8fp16_va") - target_compile_options(${target_name} PRIVATE -march=armv8.4-a+fp16+dotprod) - endif () + target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags}) if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING) endif () - # NOTE: If you want to debug the native code, you can uncomment if and endif - # if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") - - target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG) - target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden) - target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections) + if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG) + target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden) + target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections) - target_link_options(${target_name} PRIVATE -Wl,--gc-sections) - target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL) - target_link_options(${target_name} PRIVATE -flto) - - # endif () + target_link_options(${target_name} PRIVATE -Wl,--gc-sections) + target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL) + target_link_options(${target_name} PRIVATE -flto) + endif () endfunction() -build_library("rnllama") # Default target +# Default target (no specific CPU features) +build_library("rnllama" "") if (${ANDROID_ABI} STREQUAL "arm64-v8a") - build_library("rnllama_v8fp16_va") + # ARM64 targets + build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod") + build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod") + build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16") + build_library("rnllama_v8" "-march=armv8-a") +elseif (${ANDROID_ABI} STREQUAL "x86_64") + # x86_64 target + build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt") + endif () diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index c37fa38a..e2044cf1 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -237,29 +237,33 @@ public void release() { static { Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]); if (LlamaContext.isArm64V8a()) { - boolean loadV8fp16 = false; - if (LlamaContext.isArm64V8a()) { - // ARMv8.2a needs runtime detection support - String cpuInfo = LlamaContext.cpuInfo(); - if (cpuInfo != null) { - Log.d(NAME, "CPU info: " + cpuInfo); - if (cpuInfo.contains("fphp")) { - Log.d(NAME, "CPU supports fp16 arithmetic"); - loadV8fp16 = true; - } - } - } + String cpuFeatures = LlamaContext.getCpuFeatures(); + Log.d(NAME, "CPU features: " + cpuFeatures); - if (loadV8fp16) { - Log.d(NAME, "Loading librnllama_v8fp16_va.so"); - System.loadLibrary("rnllama_v8fp16_va"); - } else { - Log.d(NAME, "Loading librnllama.so"); - System.loadLibrary("rnllama"); - } + boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp"); + boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp"); + boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes"); + boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat"); + + if (isAtLeastArmV84 && hasFp16 && hasDotProd) { + Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so"); + System.loadLibrary("rnllama_v8_4_fp16_dotprod"); + } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) { + Log.d(NAME, "Loading librnllama_v8_2_fp16_dotprod.so"); + System.loadLibrary("rnllama_v8_2_fp16_dotprod"); + } else if (isAtLeastArmV82 && hasFp16) { + Log.d(NAME, "Loading librnllama_v8_2_fp16.so"); + System.loadLibrary("rnllama_v8_2_fp16"); + } else { + Log.d(NAME, "Loading librnllama_v8.so"); + System.loadLibrary("rnllama_v8"); + } } else if (LlamaContext.isX86_64()) { - Log.d(NAME, "Loading librnllama.so"); - System.loadLibrary("rnllama"); + Log.d(NAME, "Loading librnllama_x86_64.so"); + System.loadLibrary("rnllama_x86_64"); + } else { + Log.d(NAME, "Loading default librnllama.so"); + System.loadLibrary("rnllama"); } } @@ -271,20 +275,23 @@ private static boolean isX86_64() { return Build.SUPPORTED_ABIS[0].equals("x86_64"); } - private static String cpuInfo() { + private static String getCpuFeatures() { File file = new File("/proc/cpuinfo"); StringBuilder stringBuilder = new StringBuilder(); try { BufferedReader bufferedReader = new BufferedReader(new FileReader(file)); String line; while ((line = bufferedReader.readLine()) != null) { + if (line.startsWith("Features")) { stringBuilder.append(line); + break; + } } bufferedReader.close(); return stringBuilder.toString(); } catch (IOException e) { Log.w(NAME, "Couldn't read /proc/cpuinfo", e); - return null; + return ""; } } diff --git a/cpp/common.cpp b/cpp/common.cpp index e4342ce5..9e2ffb3e 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -1,3 +1,7 @@ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + #include "common.h" // Change JSON_ASSERT from assert() to LM_GGML_ASSERT: #define JSON_ASSERT LM_GGML_ASSERT @@ -6,21 +10,21 @@ #include "llama.h" #include -#include +#include #include +#include +#include #include #include #include -#include #include +#include #include #include #include #include #include #include -#include -#include #if defined(__APPLE__) && defined(__MACH__) #include @@ -196,6 +200,12 @@ int32_t cpu_get_num_math() { // CLI argument parsing // +void gpt_params_handle_hf_token(gpt_params & params) { + if (params.hf_token.empty() && std::getenv("HF_TOKEN")) { + params.hf_token = std::getenv("HF_TOKEN"); + } +} + void gpt_params_handle_model_default(gpt_params & params) { if (!params.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model @@ -205,19 +215,13 @@ void gpt_params_handle_model_default(gpt_params & params) { } params.hf_file = params.model; } else if (params.model.empty()) { - std::string cache_directory = fs_get_cache_directory(); - const bool success = fs_create_directory_with_parents(cache_directory); - if (!success) { - throw std::runtime_error("failed to create cache directory: " + cache_directory); - } - params.model = cache_directory + string_split(params.hf_file, '/').back(); + params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); } } else if (!params.model_url.empty()) { if (params.model.empty()) { auto f = string_split(params.model_url, '#').front(); f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; + params.model = fs_get_cache_file(string_split(f, '/').back()); } } else if (params.model.empty()) { params.model = DEFAULT_MODEL_PATH; @@ -243,15 +247,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } - if (params.prompt_cache_all && - (params.interactive || params.interactive_first || - params.instruct)) { - + if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } gpt_params_handle_model_default(params); + gpt_params_handle_hf_token(params); + if (params.escape) { string_process_escapes(params.prompt); string_process_escapes(params.input_prefix); @@ -271,39 +274,39 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { - bool result = true; + const auto params_org = params; // the example can modify the default params + try { - if (!gpt_params_parse_ex(argc, argv, params)) { - gpt_params_print_usage(argc, argv, gpt_params()); - exit(0); + if (!gpt_params_parse_ex(argc, argv, params) || params.usage) { + params = params_org; + params.usage = true; + return false; } - } - catch (const std::invalid_argument & ex) { + } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); - gpt_params_print_usage(argc, argv, gpt_params()); - exit(1); + params = params_org; + return false; } - return result; + + return true; } +#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } + bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { + const char split_delim = ','; + llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - invalid_param = true; - return true; - } - // This is temporary, in the future the samplign state will be moved fully to llama_sampling_context. + CHECK_ARG + // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context. params.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]); return true; } if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads = std::stoi(argv[i]); if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); @@ -311,10 +314,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-tb" || arg == "--threads-batch") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_batch = std::stoi(argv[i]); if (params.n_threads_batch <= 0) { params.n_threads_batch = std::thread::hardware_concurrency(); @@ -322,10 +322,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-td" || arg == "--threads-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_draft = std::stoi(argv[i]); if (params.n_threads_draft <= 0) { params.n_threads_draft = std::thread::hardware_concurrency(); @@ -333,10 +330,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-tbd" || arg == "--threads-batch-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_batch_draft = std::stoi(argv[i]); if (params.n_threads_batch_draft <= 0) { params.n_threads_batch_draft = std::thread::hardware_concurrency(); @@ -344,10 +338,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.prompt = argv[i]; return true; } @@ -355,11 +346,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.escape = true; return true; } + if (arg == "--no-escape") { + params.escape = false; + return true; + } if (arg == "--prompt-cache") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.path_prompt_cache = argv[i]; return true; } @@ -372,10 +364,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-bf" || arg == "--binary-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i], std::ios::binary); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -391,10 +380,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-f" || arg == "--file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -409,67 +395,54 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } - if (arg == "-n" || arg == "--n-predict") { - if (++i >= argc) { + if (arg == "--in-file") { + CHECK_ARG + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; return true; } + params.in_files.push_back(argv[i]); + return true; + } + if (arg == "-n" || arg == "--predict" || arg == "--n-predict") { + CHECK_ARG params.n_predict = std::stoi(argv[i]); return true; } if (arg == "--top-k") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.top_k = std::stoi(argv[i]); return true; } if (arg == "-c" || arg == "--ctx-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_ctx = std::stoi(argv[i]); return true; } if (arg == "--grp-attn-n" || arg == "-gan") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.grp_attn_n = std::stoi(argv[i]); return true; } if (arg == "--grp-attn-w" || arg == "-gaw") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.grp_attn_w = std::stoi(argv[i]); return true; } if (arg == "--rope-freq-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_base = std::stof(argv[i]); return true; } if (arg == "--rope-freq-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_scale = std::stof(argv[i]); return true; } if (arg == "--rope-scaling") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } @@ -478,217 +451,148 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--rope-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_scale = 1.0f / std::stof(argv[i]); return true; } if (arg == "--yarn-orig-ctx") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_orig_ctx = std::stoi(argv[i]); return true; } if (arg == "--yarn-ext-factor") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_ext_factor = std::stof(argv[i]); return true; } if (arg == "--yarn-attn-factor") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_attn_factor = std::stof(argv[i]); return true; } if (arg == "--yarn-beta-fast") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_beta_fast = std::stof(argv[i]); return true; } if (arg == "--yarn-beta-slow") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_beta_slow = std::stof(argv[i]); return true; } if (arg == "--pooling") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else { invalid_param = true; } + return true; + } + if (arg == "--attention") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } + else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; } else { invalid_param = true; } return true; } if (arg == "--defrag-thold" || arg == "-dt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.defrag_thold = std::stof(argv[i]); return true; } if (arg == "--samplers") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const auto sampler_names = string_split(argv[i], ';'); sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.top_p = std::stof(argv[i]); return true; } if (arg == "--min-p") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.min_p = std::stof(argv[i]); return true; } if (arg == "--temp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.temp = std::stof(argv[i]); sparams.temp = std::max(sparams.temp, 0.0f); return true; } if (arg == "--tfs") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.tfs_z = std::stof(argv[i]); return true; } if (arg == "--typical") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.typical_p = std::stof(argv[i]); return true; } if (arg == "--repeat-last-n") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_last_n = std::stoi(argv[i]); sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); return true; } if (arg == "--repeat-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_repeat = std::stof(argv[i]); return true; } if (arg == "--frequency-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_freq = std::stof(argv[i]); return true; } if (arg == "--presence-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_present = std::stof(argv[i]); return true; } if (arg == "--dynatemp-range") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.dynatemp_range = std::stof(argv[i]); return true; } if (arg == "--dynatemp-exp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.dynatemp_exponent = std::stof(argv[i]); return true; } if (arg == "--mirostat") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat = std::stoi(argv[i]); return true; } if (arg == "--mirostat-lr") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat_eta = std::stof(argv[i]); return true; } if (arg == "--mirostat-ent") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat_tau = std::stof(argv[i]); return true; } if (arg == "--cfg-negative-prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.cfg_negative_prompt = argv[i]; return true; } if (arg == "--cfg-negative-prompt-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -702,203 +606,131 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--cfg-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.cfg_scale = std::stof(argv[i]); return true; } if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_batch = std::stoi(argv[i]); return true; } if (arg == "-ub" || arg == "--ubatch-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_ubatch = std::stoi(argv[i]); return true; } if (arg == "--keep") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_keep = std::stoi(argv[i]); return true; } if (arg == "--draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_draft = std::stoi(argv[i]); return true; } if (arg == "--chunks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_chunks = std::stoi(argv[i]); return true; } if (arg == "-np" || arg == "--parallel") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_parallel = std::stoi(argv[i]); return true; } if (arg == "-ns" || arg == "--sequences") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_sequences = std::stoi(argv[i]); return true; } if (arg == "--p-split" || arg == "-ps") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.p_split = std::stof(argv[i]); return true; } if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model = argv[i]; return true; } if (arg == "-md" || arg == "--model-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_draft = argv[i]; return true; } if (arg == "-a" || arg == "--alias") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_alias = argv[i]; return true; } if (arg == "-mu" || arg == "--model-url") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_url = argv[i]; return true; } - if (arg == "-hfr" || arg == "--hf-repo") { + if (arg == "-hft" || arg == "--hf-token") { if (++i >= argc) { - invalid_param = true; - return true; + invalid_param = true; + return true; } + params.hf_token = argv[i]; + return true; + } + if (arg == "-hfr" || arg == "--hf-repo") { + CHECK_ARG params.hf_repo = argv[i]; return true; } if (arg == "-hff" || arg == "--hf-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hf_file = argv[i]; return true; } if (arg == "--lora") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_adapter.emplace_back(argv[i], 1.0f); - params.use_mmap = false; return true; } if (arg == "--lora-scaled") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const char* lora_adapter = argv[i]; - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); - params.use_mmap = false; return true; } if (arg == "--lora-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_base = argv[i]; return true; } if (arg == "--control-vector") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vectors.push_back({ 1.0f, argv[i], }); return true; } if (arg == "--control-vector-scaled") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const char* fname = argv[i]; - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vectors.push_back({ std::stof(argv[i]), fname, }); return true; } if (arg == "--control-vector-layer-range") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vector_layer_start = std::stoi(argv[i]); - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vector_layer_end = std::stoi(argv[i]); return true; } if (arg == "--mmproj") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.mmproj = argv[i]; return true; } if (arg == "--image") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.image.emplace_back(argv[i]); return true; } @@ -906,32 +738,35 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.interactive = true; return true; } - if (arg == "--interactive-specials") { - params.interactive_specials = true; - return true; - } - if (arg == "--special") { + if (arg == "-sp" || arg == "--special") { params.special = true; return true; } - if (arg == "--embedding") { + if (arg == "--embedding" || arg == "--embeddings") { params.embedding = true; return true; } - if (arg == "--interactive-first") { - params.interactive_first = true; + if (arg == "--embd-normalize") { + CHECK_ARG + params.embd_normalize = std::stoi(argv[i]); return true; } - if (arg == "-ins" || arg == "--instruct") { - params.instruct = true; + if (arg == "--embd-output-format") { + CHECK_ARG + params.embd_out = argv[i]; return true; } - if (arg == "-cnv" || arg == "--conversation") { - params.conversation = true; + if (arg == "--embd-separator") { + CHECK_ARG + params.embd_sep = argv[i]; return true; } - if (arg == "-cml" || arg == "--chatml") { - params.chatml = true; + if (arg == "-if" || arg == "--interactive-first") { + params.interactive_first = true; + return true; + } + if (arg == "-cnv" || arg == "--conversation") { + params.conversation = true; return true; } if (arg == "--infill") { @@ -954,7 +789,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cache_type_v = argv[++i]; return true; } - if (arg == "--multiline-input") { + if (arg == "-mli" || arg == "--multiline-input") { params.multiline_input = true; return true; } @@ -966,11 +801,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-nocb" || arg == "--no-cont-batching") { + params.cont_batching = false; + return true; + } if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; return true; } - if (arg == "--color") { + if (arg == "-co" || arg == "--color") { params.use_color = true; return true; } @@ -978,46 +817,34 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mlock = true; return true; } - if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { - if (++i >= argc) { - invalid_param = true; - return true; - } + if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { + CHECK_ARG params.n_gpu_layers = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); } return true; } - if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") { + CHECK_ARG params.n_gpu_layers_draft = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); } return true; } if (arg == "--main-gpu" || arg == "-mg") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.main_gpu = std::stoi(argv[i]); -#ifndef LM_GGML_USE_CUDA_SYCL - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the main GPU has no effect.\n"); -#endif // LM_GGML_USE_CUDA_SYCL +#ifndef LM_GGML_USE_CUDA_SYCL_VULKAN + fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); +#endif // LM_GGML_USE_CUDA_SYCL_VULKAN return true; } if (arg == "--split-mode" || arg == "-sm") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string arg_next = argv[i]; if (arg_next == "none") { params.split_mode = LLAMA_SPLIT_MODE_NONE; @@ -1036,16 +863,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } -#ifndef LM_GGML_USE_CUDA_SYCL - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the split mode has no effect.\n"); -#endif // LM_GGML_USE_CUDA_SYCL +#ifndef LM_GGML_USE_CUDA_SYCL_VULKAN + fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the split mode has no effect.\n"); +#endif // LM_GGML_USE_CUDA_SYCL_VULKAN return true; } if (arg == "--tensor-split" || arg == "-ts") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string arg_next = argv[i]; // split string by , and / @@ -1070,10 +894,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--rpc") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rpc_servers = argv[i]; return true; } @@ -1082,10 +903,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--numa") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "distribute" || value == "") { params.numa = LM_GGML_NUMA_STRATEGY_DISTRIBUTE; } else if (value == "isolate") { params.numa = LM_GGML_NUMA_STRATEGY_ISOLATE; } @@ -1093,6 +911,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa else { invalid_param = true; } return true; } + if (arg == "-v" || arg == "--verbose") { + params.verbosity = 1; + return true; + } + if (arg == "--verbosity") { + CHECK_ARG + params.verbosity = std::stoi(argv[i]); + return true; + } if (arg == "--verbose-prompt") { params.verbose_prompt = true; return true; @@ -1102,18 +929,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-r" || arg == "--reverse-prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.antiprompt.emplace_back(argv[i]); return true; } if (arg == "-ld" || arg == "--logdir") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.logdir = argv[i]; if (params.logdir.back() != DIRECTORY_SEPARATOR) { @@ -1122,209 +943,395 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-lcs" || arg == "--lookup-cache-static") { - if (++i >= argc) { + CHECK_ARG + params.lookup_cache_static = argv[i]; + return true; + } + if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { + CHECK_ARG + params.lookup_cache_dynamic = argv[i]; + return true; + } + if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { + CHECK_ARG + params.logits_file = argv[i]; + return true; + } + if (arg == "--perplexity" || arg == "--all-logits") { + params.logits_all = true; + return true; + } + if (arg == "--ppl-stride") { + CHECK_ARG + params.ppl_stride = std::stoi(argv[i]); + return true; + } + if (arg == "--ppl-output-type") { + CHECK_ARG + params.ppl_output_type = std::stoi(argv[i]); + return true; + } + if (arg == "-ptc" || arg == "--print-token-count") { + CHECK_ARG + params.n_print = std::stoi(argv[i]); + return true; + } + if (arg == "--check-tensors") { + params.check_tensors = true; + return true; + } + if (arg == "--hellaswag") { + params.hellaswag = true; + return true; + } + if (arg == "--hellaswag-tasks") { + CHECK_ARG + params.hellaswag_tasks = std::stoi(argv[i]); + return true; + } + if (arg == "--winogrande") { + params.winogrande = true; + return true; + } + if (arg == "--winogrande-tasks") { + CHECK_ARG + params.winogrande_tasks = std::stoi(argv[i]); + return true; + } + if (arg == "--multiple-choice") { + params.multiple_choice = true; + return true; + } + if (arg == "--multiple-choice-tasks") { + CHECK_ARG + params.multiple_choice_tasks = std::stoi(argv[i]); + return true; + } + if (arg == "--kl-divergence") { + params.kl_divergence = true; + return true; + } + if (arg == "--ignore-eos") { + params.ignore_eos = true; + return true; + } + if (arg == "--penalize-nl") { + sparams.penalize_nl = true; + return true; + } + if (arg == "-l" || arg == "--logit-bias") { + CHECK_ARG + std::stringstream ss(argv[i]); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + } + else { + throw std::exception(); + } + } + catch (const std::exception&) { invalid_param = true; return true; } - params.lookup_cache_static = argv[i]; return true; } - if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { - if (++i >= argc) { + if (arg == "-h" || arg == "--help" || arg == "--usage" ) { + params.usage = true; + return true; + } + if (arg == "--version") { + fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); + fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); + exit(0); + } + if (arg == "--in-prefix-bos") { + params.input_prefix_bos = true; + params.enable_chat_template = false; + return true; + } + if (arg == "--in-prefix") { + CHECK_ARG + params.input_prefix = argv[i]; + params.enable_chat_template = false; + return true; + } + if (arg == "--in-suffix") { + CHECK_ARG + params.input_suffix = argv[i]; + params.enable_chat_template = false; + return true; + } + if (arg == "--spm-infill") { + params.spm_infill = true; + return true; + } + if (arg == "--grammar") { + CHECK_ARG + sparams.grammar = argv[i]; + return true; + } + if (arg == "--grammar-file") { + CHECK_ARG + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; return true; } - params.lookup_cache_dynamic = argv[i]; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(sparams.grammar) + ); return true; } - if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { - if (++i >= argc) { + if (arg == "-j" || arg == "--json-schema") { + CHECK_ARG + sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); + return true; + } + if (arg == "--override-kv") { + CHECK_ARG + if (!string_parse_kv_override(argv[i], params.kv_overrides)) { + fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; return true; } - params.logits_file = argv[i]; return true; } - if (arg == "--perplexity" || arg == "--all-logits") { - params.logits_all = true; + if (arg == "--host") { + CHECK_ARG + params.hostname = argv[i]; return true; } - if (arg == "--ppl-stride") { - if (++i >= argc) { + if (arg == "--port") { + CHECK_ARG + params.port = std::stoi(argv[i]); + return true; + } + if (arg == "--path") { + CHECK_ARG + params.public_path = argv[i]; + return true; + } + if (arg == "--api-key") { + CHECK_ARG + params.api_keys.push_back(argv[i]); + return true; + } + if (arg == "--api-key-file") { + CHECK_ARG + std::ifstream key_file(argv[i]); + if (!key_file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; return true; } - params.ppl_stride = std::stoi(argv[i]); + std::string key; + while (std::getline(key_file, key)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } + key_file.close(); + return true; + } + if (arg == "--ssl-key-file") { + CHECK_ARG + params.ssl_file_key = argv[i]; + return true; + } + if (arg == "--ssl-cert-file") { + CHECK_ARG + params.ssl_file_cert = argv[i]; + return true; + } + if (arg == "--timeout" || arg == "-to") { + CHECK_ARG + params.timeout_read = std::stoi(argv[i]); + params.timeout_write = std::stoi(argv[i]); + return true; + } + if (arg == "--threads-http") { + CHECK_ARG + params.n_threads_http = std::stoi(argv[i]); + return true; + } + if (arg == "-spf" || arg == "--system-prompt-file") { + CHECK_ARG + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + return true; + } + std::string system_prompt; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(system_prompt) + ); + params.system_prompt = system_prompt; + return true; + } + if (arg == "--log-format") { + CHECK_ARG + if (std::strcmp(argv[i], "json") == 0) { + params.log_json = true; + } else if (std::strcmp(argv[i], "text") == 0) { + params.log_json = false; + } else { + invalid_param = true; + return true; + } + return true; + } + if (arg == "--no-slots") { + params.endpoint_slots = false; + return true; + } + if (arg == "--metrics") { + params.endpoint_metrics = true; + return true; + } + if (arg == "--slot-save-path") { + CHECK_ARG + params.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.slot_save_path += DIRECTORY_SEPARATOR; + } return true; } - if (arg == "-ptc" || arg == "--print-token-count") { - if (++i >= argc) { + if (arg == "--chat-template") { + CHECK_ARG + if (!llama_chat_verify_template(argv[i])) { + fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); + fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); invalid_param = true; return true; } - params.n_print = std::stoi(argv[i]); + params.chat_template = argv[i]; return true; } - if (arg == "--check-tensors") { - params.check_tensors = true; + if (arg == "--slot-prompt-similarity" || arg == "-sps") { + CHECK_ARG + params.slot_prompt_similarity = std::stof(argv[i]); return true; } - if (arg == "--ppl-output-type") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.ppl_output_type = std::stoi(argv[i]); + if (arg == "-pps") { + params.is_pp_shared = true; return true; } - if (arg == "--hellaswag") { - params.hellaswag = true; + if (arg == "-npp") { + CHECK_ARG + auto p = string_split(argv[i], split_delim); + params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); return true; } - if (arg == "--hellaswag-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.hellaswag_tasks = std::stoi(argv[i]); + if (arg == "-ntg") { + CHECK_ARG + auto p = string_split(argv[i], split_delim); + params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); return true; } - if (arg == "--winogrande") { - params.winogrande = true; + if (arg == "-npl") { + CHECK_ARG + auto p = string_split(argv[i], split_delim); + params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); return true; } - if (arg == "--winogrande-tasks") { - if (++i >= argc) { + if (arg == "--context-file") { + CHECK_ARG + std::ifstream file(argv[i], std::ios::binary); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; return true; } - params.winogrande_tasks = std::stoi(argv[i]); + params.context_files.push_back(argv[i]); return true; } - if (arg == "--multiple-choice") { - params.multiple_choice = true; + if (arg == "--chunk-size") { + CHECK_ARG + params.chunk_size = std::stoi(argv[i]); return true; } - if (arg == "--multiple-choice-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.multiple_choice_tasks = std::stoi(argv[i]); + if (arg == "--chunk-separator") { + CHECK_ARG + params.chunk_separator = argv[i]; return true; } - if (arg == "--kl-divergence") { - params.kl_divergence = true; + if (arg == "--junk") { + CHECK_ARG + params.n_junk = std::stoi(argv[i]); return true; } - if (arg == "--ignore-eos") { - params.ignore_eos = true; + if (arg == "--pos") { + CHECK_ARG + params.i_pos = std::stoi(argv[i]); return true; } - if (arg == "--penalize-nl") { - sparams.penalize_nl = true; + if (arg == "-o" || arg == "--output" || arg == "--output-file") { + CHECK_ARG + params.out_file = argv[i]; + params.cvector_outfile = argv[i]; return true; } - if (arg == "-l" || arg == "--logit-bias") { - if (++i >= argc) { - invalid_param = true; - return true; - } - std::stringstream ss(argv[i]); - llama_token key; - char sign; - std::string value_str; - try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); - } - else { - throw std::exception(); - } - } - catch (const std::exception&) { - invalid_param = true; - return true; - } + if (arg == "-ofreq" || arg == "--output-frequency") { + CHECK_ARG + params.n_out_freq = std::stoi(argv[i]); return true; } - if (arg == "-h" || arg == "--help") { - gpt_params_print_usage(argc, argv, gpt_params()); - exit(0); - } - if (arg == "--version") { - fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); - exit(0); + if (arg == "--save-frequency") { + CHECK_ARG + params.n_save_freq = std::stoi(argv[i]); + return true; } - if (arg == "--random-prompt") { - params.random_prompt = true; + if (arg == "--process-output") { + params.process_output = true; return true; } - if (arg == "--in-prefix-bos") { - params.input_prefix_bos = true; + if (arg == "--no-ppl") { + params.compute_ppl = false; return true; } - if (arg == "--in-prefix") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.input_prefix = argv[i]; + if (arg == "--chunk" || arg == "--from-chunk") { + CHECK_ARG + params.i_chunk = std::stoi(argv[i]); return true; } - if (arg == "--in-suffix") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.input_suffix = argv[i]; + // cvector params + if (arg == "--positive-file") { + CHECK_ARG + params.cvector_positive_file = argv[i]; return true; } - if (arg == "--grammar") { - if (++i >= argc) { - invalid_param = true; - return true; - } - sparams.grammar = argv[i]; + if (arg == "--negative-file") { + CHECK_ARG + params.cvector_negative_file = argv[i]; return true; } - if (arg == "--grammar-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(sparams.grammar) - ); + if (arg == "--pca-batch") { + CHECK_ARG + params.n_pca_batch = std::stoi(argv[i]); return true; } - if (arg == "-j" || arg == "--json-schema") { - if (++i >= argc) { - invalid_param = true; - return true; - } - sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); + if (arg == "--pca-iter") { + CHECK_ARG + params.n_pca_iterations = std::stoi(argv[i]); return true; } - if (arg == "--override-kv") { - if (++i >= argc) { - invalid_param = true; - return true; - } - if (!string_parse_kv_override(argv[i], params.kv_overrides)) { - fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); - invalid_param = true; - return true; - } + if (arg == "--method") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } + else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } + else { invalid_param = true; } return true; } #ifndef LOG_DISABLE_LOGS @@ -1338,10 +1345,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa // We have a matching known parameter requiring an argument, // now we need to check if there is anything after this argv // and flag invalid_param or parse it. - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) { invalid_param = true; return true; @@ -1354,6 +1358,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return false; } +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +#endif + void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const llama_sampling_params & sparams = params.sparams; @@ -1365,198 +1379,333 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param } sampler_type_names.pop_back(); - printf("\n"); - printf("usage: %s [options]\n", argv[0]); - printf("\n"); - printf("options:\n"); - printf(" -h, --help show this help message and exit\n"); - printf(" --version show version and build info\n"); - printf(" -i, --interactive run in interactive mode\n"); - printf(" --special special tokens output enabled\n"); - printf(" --interactive-specials allow special tokens in user text, in interactive mode\n"); - printf(" --interactive-first run in interactive mode and wait for input right away\n"); - printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); - printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); - printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); - printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); - printf(" -r PROMPT, --reverse-prompt PROMPT\n"); - printf(" halt generation at PROMPT, return control in interactive mode\n"); - printf(" (can be specified more than once for multiple prompts).\n"); - printf(" --color colorise output to distinguish prompt and user input from generations\n"); - printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); - printf(" -tb N, --threads-batch N\n"); - printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); - printf(" -td N, --threads-draft N"); - printf(" number of threads to use during generation (default: same as --threads)\n"); - printf(" -tbd N, --threads-batch-draft N\n"); - printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n"); - printf(" -p PROMPT, --prompt PROMPT\n"); - printf(" prompt to start generation with (default: empty)\n"); - printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - printf(" --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n"); - printf(" --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); - printf(" not supported with --interactive or other interactive options\n"); - printf(" --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); - printf(" --random-prompt start with a randomized prompt.\n"); - printf(" --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); - printf(" --in-prefix STRING string to prefix user inputs with (default: empty)\n"); - printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); - printf(" -f FNAME, --file FNAME\n"); - printf(" prompt file to start generation.\n"); - printf(" -bf FNAME, --binary-file FNAME\n"); - printf(" binary file containing multiple choice tasks.\n"); - printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); - printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); - printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch); - printf(" -ub N, --ubatch-size N\n"); - printf(" physical maximum batch size (default: %d)\n", params.n_ubatch); - printf(" --samplers samplers that will be used for generation in the order, separated by \';\'\n"); - printf(" (default: %s)\n", sampler_type_names.c_str()); - printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str()); - printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); - printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); - printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); - printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); - printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); - printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); - printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); - printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); - printf(" --dynatemp-range N dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range); - printf(" --dynatemp-exp N dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent); - printf(" --mirostat N use Mirostat sampling.\n"); - printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); - printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); - printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta); - printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau); - printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); - printf(" modifies the likelihood of token appearing in the completion,\n"); - printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); - printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); - printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); - printf(" --grammar-file FNAME file to read grammar from\n"); - printf(" -j SCHEMA, --json-schema SCHEMA\n"); - printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n"); - printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n"); - printf(" --cfg-negative-prompt PROMPT\n"); - printf(" negative prompt to use for guidance. (default: empty)\n"); - printf(" --cfg-negative-prompt-file FNAME\n"); - printf(" negative prompt file to use for guidance. (default: empty)\n"); - printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); - printf(" --rope-scaling {none,linear,yarn}\n"); - printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n"); - printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n"); - printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); - printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n"); - printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n"); - printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n"); - printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); - printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); - printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); - printf(" --pooling {none,mean,cls}\n"); - printf(" pooling type for embeddings, use model default if unspecified\n"); - printf(" -dt N, --defrag-thold N\n"); - printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); - printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); - printf(" --penalize-nl penalize newline tokens\n"); - printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); - printf(" --all-logits return logits for all tokens in the batch (default: disabled)\n"); - printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); - printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); - printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n"); - printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks); - printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n"); - printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks); - printf(" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base\n"); - printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); - printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); - printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); - printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); - printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); - printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); - printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); - printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); - printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n"); + struct option_info { + LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) + option_info(const std::string & tags, const char * args, const char * desc, ...) : tags(tags), args(args), desc(desc) { + va_list args_list; + va_start(args_list, desc); + char buffer[1024]; + vsnprintf(buffer, sizeof(buffer), desc, args_list); + va_end(args_list); + this->desc = buffer; + } + + option_info(const std::string & grp) : grp(grp) {} + + std::string tags; + std::string args; + std::string desc; + std::string grp; + }; + + std::vector options; + + // TODO: filter by tags + + options.push_back({ "general" }); + options.push_back({ "*", "-h, --help, --usage", "print usage and exit" }); + options.push_back({ "*", " --version", "show version and build info" }); + options.push_back({ "*", "-v, --verbose", "print verbose information" }); + options.push_back({ "*", " --verbosity N", "set specific verbosity level (default: %d)", params.verbosity }); + options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" }); + options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); + options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); + options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); + options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads }); + options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); + options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); + options.push_back({ "speculative", "-tbd, --threads-batch-draft N", + "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); + options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); + options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split }); + options.push_back({ "*", "-lcs, --lookup-cache-static FNAME", + "path to static lookup cache to use for lookup decoding (not updated by generation)" }); + options.push_back({ "*", "-lcd, --lookup-cache-dynamic FNAME", + "path to dynamic lookup cache to use for lookup decoding (updated by generation)" }); + + options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); + options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); + options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); + options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); + options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); + options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); + options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" + "in conversation mode, this will be used as system prompt\n" + "(default: '%s')", params.prompt.c_str() }); + options.push_back({ "*", "-f, --file FNAME", "a file containing the prompt (default: none)" }); + options.push_back({ "*", " --in-file FNAME", "an input file (repeat to specify multiple files)" }); + options.push_back({ "*", "-bf, --binary-file FNAME", "binary file containing the prompt (default: none)" }); + options.push_back({ "*", "-e, --escape", "process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false" }); + options.push_back({ "*", " --no-escape", "do not process escape sequences" }); + options.push_back({ "main", "-ptc, --print-token-count N", "print token count every N tokens (default: %d)", params.n_print }); + options.push_back({ "main", " --prompt-cache FNAME", "file to cache prompt state for faster startup (default: none)" }); + options.push_back({ "main", " --prompt-cache-all", "if specified, saves user input and generations to cache as well\n" + "not supported with --interactive or other interactive options" }); + options.push_back({ "main", " --prompt-cache-ro", "if specified, uses the prompt cache but does not update it" }); + options.push_back({ "main", "-r, --reverse-prompt PROMPT", + "halt generation at PROMPT, return control in interactive mode\n" + "can be specified more than once for multiple prompts" }); + options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" }); + options.push_back({ "main", "-cnv, --conversation", "run in conversation mode, does not print special tokens and suffix/prefix\n" + "if suffix/prefix are not specified, default chat template will be used\n" + "(default: %s)", params.conversation ? "true" : "false" }); + options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" }); + options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" }); + options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" }); + options.push_back({ "main infill", " --in-prefix-bos", "prefix BOS to user inputs, preceding the `--in-prefix` string" }); + options.push_back({ "main infill", " --in-prefix STRING", "string to prefix user inputs with (default: empty)" }); + options.push_back({ "main infill", " --in-suffix STRING", "string to suffix after user inputs with (default: empty)" }); + options.push_back({ "server infill", + " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" }); + + options.push_back({ "sampling" }); + options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" + "(default: %s)", sampler_type_names.c_str() }); + options.push_back({ "*", " --sampling-seq SEQUENCE", + "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); + options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); + options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" }); + options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp }); + options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k }); + options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); + options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); + options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); + options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p }); + options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n }); + options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat }); + options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present }); + options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq }); + options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range }); + options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent }); + options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n" + "Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n" + "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", sparams.mirostat }); + options.push_back({ "*", " --mirostat-lr N", "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta }); + options.push_back({ "*", " --mirostat-ent N", "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau }); + options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" + "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" + "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); + options.push_back({ "main", " --cfg-negative-prompt PROMPT", + "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); + options.push_back({ "main", " --cfg-negative-prompt-file FNAME", + "negative prompt file to use for guidance" }); + options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); + options.push_back({ "main", " --chat-template JINJA_TEMPLATE", + "set custom jinja chat template (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted:\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + options.push_back({ "grammar" }); + options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); + options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); + options.push_back({ "*", "-j, --json-schema SCHEMA", + "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\n" + "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" }); + + options.push_back({ "embedding" }); + options.push_back({ "embedding", " --pooling {none,mean,cls,last}", + "pooling type for embeddings, use model default if unspecified" }); + options.push_back({ "embedding", " --attention {causal,non-causal}", + "attention type for embeddings, use model default if unspecified" }); + + options.push_back({ "context hacking" }); + options.push_back({ "*", " --rope-scaling {none,linear,yarn}", + "RoPE frequency scaling method, defaults to linear unless specified by the model" }); + options.push_back({ "*", " --rope-scale N", "RoPE context scaling factor, expands context by a factor of N" }); + options.push_back({ "*", " --rope-freq-base N", "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)" }); + options.push_back({ "*", " --rope-freq-scale N", "RoPE frequency scaling factor, expands context by a factor of 1/N" }); + options.push_back({ "*", " --yarn-orig-ctx N", "YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx }); + options.push_back({ "*", " --yarn-ext-factor N", "YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor }); + options.push_back({ "*", " --yarn-attn-factor N", "YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor }); + options.push_back({ "*", " --yarn-beta-slow N", "YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow }); + options.push_back({ "*", " --yarn-beta-fast N", "YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast }); + options.push_back({ "*", "-gan, --grp-attn-n N", "group-attention factor (default: %d)", params.grp_attn_n }); + options.push_back({ "*", "-gaw, --grp-attn-w N", "group-attention width (default: %.1f)", (double)params.grp_attn_w }); + options.push_back({ "*", "-dkvc, --dump-kv-cache", "verbose print of the KV cache" }); + options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" }); + options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() }); + options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() }); + + options.push_back({ "perplexity" }); + options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" }); + options.push_back({ "perplexity", " --hellaswag", "compute HellaSwag score over random tasks from datafile supplied with -f" }); + options.push_back({ "perplexity", " --hellaswag-tasks N", "number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks }); + options.push_back({ "perplexity", " --winogrande", "compute Winogrande score over random tasks from datafile supplied with -f" }); + options.push_back({ "perplexity", " --winogrande-tasks N", "number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks }); + options.push_back({ "perplexity", " --multiple-choice", "compute multiple choice score over random tasks from datafile supplied with -f" }); + options.push_back({ "perplexity", " --multiple-choice-tasks N", + "number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks }); + options.push_back({ "perplexity", " --kl-divergence", "computes KL-divergence to logits provided via --kl-divergence-base" }); + options.push_back({ "perplexity", " --ppl-stride N", "stride for perplexity calculation (default: %d)", params.ppl_stride }); + options.push_back({ "perplexity", " --ppl-output-type {0,1}", + "output type for perplexity calculation (default: %d)", params.ppl_output_type }); + + options.push_back({ "parallel" }); + options.push_back({ "*", "-dt, --defrag-thold N", "KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold }); + options.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel }); + options.push_back({ "*", "-ns, --sequences N", "number of sequences to decode (default: %d)", params.n_sequences }); + options.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" }); + options.push_back({ "*", "-nocb, --no-cont-batching", "disable continuous batching" }); + + options.push_back({ "multi-modality" }); + options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); + options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" }); + + options.push_back({ "backend" }); + options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); + if (llama_supports_mlock()) { - printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); + options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); } if (llama_supports_mmap()) { - printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); - } - printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n"); - printf(" - distribute: spread execution evenly over all nodes\n"); - printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n"); - printf(" - numactl: use the CPU map provided by numactl\n"); - printf(" if run without this previously, it is recommended to drop the system page cache before using this\n"); - printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n"); + options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" }); + } + options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n" + " - distribute: spread execution evenly over all nodes\n" + " - isolate: only spawn threads on CPUs on the node that execution started on\n" + " - numactl: use the CPU map provided by numactl\n" + "if run without this previously, it is recommended to drop the system page cache before using this\n" + "see https://github.com/ggerganov/llama.cpp/issues/1437" }); + if (llama_supports_gpu_offload()) { - printf(" -ngl N, --n-gpu-layers N\n"); - printf(" number of layers to store in VRAM\n"); - printf(" -ngld N, --n-gpu-layers-draft N\n"); - printf(" number of layers to store in VRAM for the draft model\n"); - printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n"); - printf(" how to split the model across multiple GPUs, one of:\n"); - printf(" - none: use one GPU only\n"); - printf(" - layer (default): split layers and KV across GPUs\n"); - printf(" - row: split rows across GPUs\n"); - printf(" -ts SPLIT, --tensor-split SPLIT\n"); - printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n"); - printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); - printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu); - } - printf(" --rpc SERVERS comma separated list of RPC servers\n"); - printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false"); - printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false"); - printf(" -gan N, --grp-attn-n N\n"); - printf(" group-attention factor (default: %d)\n", params.grp_attn_n); - printf(" -gaw N, --grp-attn-w N\n"); - printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w); - printf(" -dkvc, --dump-kv-cache\n"); - printf(" verbose print of the KV cache\n"); - printf(" -nkvo, --no-kv-offload\n"); - printf(" disable KV offload\n"); - printf(" -ctk TYPE, --cache-type-k TYPE\n"); - printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str()); - printf(" -ctv TYPE, --cache-type-v TYPE\n"); - printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str()); - printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); - printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); - printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); - printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); - printf(" --control-vector FNAME\n"); - printf(" add a control vector\n"); - printf(" --control-vector-scaled FNAME S\n"); - printf(" add a control vector with user defined scaling S\n"); - printf(" --control-vector-layer-range START END\n"); - printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); - printf(" -m FNAME, --model FNAME\n"); - printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH); - printf(" -md FNAME, --model-draft FNAME\n"); - printf(" draft model for speculative decoding (default: unused)\n"); - printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); - printf(" model download url (default: unused)\n"); - printf(" -hfr REPO, --hf-repo REPO\n"); - printf(" Hugging Face model repository (default: unused)\n"); - printf(" -hff FILE, --hf-file FILE\n"); - printf(" Hugging Face model file (default: unused)\n"); - printf(" -ld LOGDIR, --logdir LOGDIR\n"); - printf(" path under which to save YAML logs (no logging if unset)\n"); - printf(" -lcs FNAME, --lookup-cache-static FNAME\n"); - printf(" path to static lookup cache to use for lookup decoding (not updated by generation)\n"); - printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n"); - printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n"); - printf(" --override-kv KEY=TYPE:VALUE\n"); - printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); - printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); - printf(" -ptc N, --print-token-count N\n"); - printf(" print token count every N tokens (default: %d)\n", params.n_print); - printf(" --check-tensors check model tensor data for invalid values\n"); - printf("\n"); + options.push_back({ "*", "-ngl, --gpu-layers N", + "number of layers to store in VRAM" }); + options.push_back({ "*", "-ngld, --gpu-layers-draft N", + "number of layers to store in VRAM for the draft model" }); + options.push_back({ "*", "-sm, --split-mode SPLIT_MODE", + "how to split the model across multiple GPUs, one of:\n" + " - none: use one GPU only\n" + " - layer (default): split layers and KV across GPUs\n" + " - row: split rows across GPUs" }); + options.push_back({ "*", "-ts, --tensor-split SPLIT", + "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1" }); + options.push_back({ "*", "-mg, --main-gpu i", "the GPU to use for the model (with split-mode = none),\n" + "or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu }); + } + + options.push_back({ "model" }); + options.push_back({ "*", " --check-tensors", "check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false" }); + options.push_back({ "*", " --override-kv KEY=TYPE:VALUE", + "advanced option to override model metadata by key. may be specified multiple times.\n" + "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false" }); + options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (implies --no-mmap)" }); + options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (implies --no-mmap)" }); + options.push_back({ "*", " --lora-base FNAME", "optional model to use as a base for the layers modified by the LoRA adapter" }); + options.push_back({ "*", " --control-vector FNAME", "add a control vector\n" + "note: this argument can be repeated to add multiple control vectors" }); + options.push_back({ "*", " --control-vector-scaled FNAME SCALE", + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors" }); + options.push_back({ "*", " --control-vector-layer-range START END", + "layer range to apply the control vector(s) to, start and end inclusive" }); + options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n" + "or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH }); + options.push_back({ "*", "-md, --model-draft FNAME", "draft model for speculative decoding (default: unused)" }); + options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" }); + options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); + options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); + options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); + + options.push_back({ "retrieval" }); + options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); + options.push_back({ "retrieval", " --chunk-size N", "minimum length of embedded text chunks (default: %d)", params.chunk_size }); + options.push_back({ "retrieval", " --chunk-separator STRING", + "separator between chunks (default: '%s')", params.chunk_separator.c_str() }); + + options.push_back({ "passkey" }); + options.push_back({ "passkey", " --junk N", "number of times to repeat the junk text (default: %d)", params.n_junk }); + options.push_back({ "passkey", " --pos N", "position of the passkey in the junk text (default: %d)", params.i_pos }); + + options.push_back({ "imatrix" }); + options.push_back({ "imatrix", "-o, --output FNAME", "output file (default: '%s')", params.out_file.c_str() }); + options.push_back({ "imatrix", " --output-frequency N", "output the imatrix every N iterations (default: %d)", params.n_out_freq }); + options.push_back({ "imatrix", " --save-frequency N", "save an imatrix copy every N iterations (default: %d)", params.n_save_freq }); + options.push_back({ "imatrix", " --process-output", "collect data for the output tensor (default: %s)", params.process_output ? "true" : "false" }); + options.push_back({ "imatrix", " --no-ppl", "do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false" }); + options.push_back({ "imatrix", " --chunk N", "start processing the input from chunk N (default: %d)", params.i_chunk }); + + options.push_back({ "bench" }); + options.push_back({ "bench", "-pps", "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" }); + options.push_back({ "bench", "-npp n0,n1,...", "number of prompt tokens" }); + options.push_back({ "bench", "-ntg n0,n1,...", "number of text generation tokens" }); + options.push_back({ "bench", "-npl n0,n1,...", "number of parallel prompts" }); + + options.push_back({ "embedding" }); + options.push_back({ "embedding", " --embd-normalize", "normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize }); + options.push_back({ "embedding", " --embd-output-format", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix" }); + options.push_back({ "embedding", " --embd-separator", "separator of embendings (default \\n) for example \"<#sep#>\"" }); + + options.push_back({ "server" }); + options.push_back({ "server", " --host HOST", "ip address to listen (default: %s)", params.hostname.c_str() }); + options.push_back({ "server", " --port PORT", "port to listen (default: %d)", params.port }); + options.push_back({ "server", " --path PATH", "path to serve static files from (default: %s)", params.public_path.c_str() }); + options.push_back({ "server", " --embedding(s)", "enable embedding endpoint (default: %s)", params.embedding ? "enabled" : "disabled" }); + options.push_back({ "server", " --api-key KEY", "API key to use for authentication (default: none)" }); + options.push_back({ "server", " --api-key-file FNAME", "path to file containing API keys (default: none)" }); + options.push_back({ "server", " --ssl-key-file FNAME", "path to file a PEM-encoded SSL private key" }); + options.push_back({ "server", " --ssl-cert-file FNAME", "path to file a PEM-encoded SSL certificate" }); + options.push_back({ "server", " --timeout N", "server read/write timeout in seconds (default: %d)", params.timeout_read }); + options.push_back({ "server", " --threads-http N", "number of threads used to process HTTP requests (default: %d)", params.n_threads_http }); + options.push_back({ "server", " --system-prompt-file FNAME", + "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications" }); + options.push_back({ "server", " --log-format {text,json}", + "log output format: json or text (default: json)" }); + options.push_back({ "server", " --metrics", "enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled" }); + options.push_back({ "server", " --no-slots", "disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled" }); + options.push_back({ "server", " --slot-save-path PATH", "path to save slot kv cache (default: disabled)" }); + options.push_back({ "server", " --chat-template JINJA_TEMPLATE", + "set custom jinja chat template (default: template taken from model's metadata)\n" + "only commonly used templates are accepted:\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); + #ifndef LOG_DISABLE_LOGS - log_print_usage(); + options.push_back({ "logging" }); + options.push_back({ "*", " --simple-io", "use basic IO for better compatibility in subprocesses and limited consoles" }); + options.push_back({ "*", "-ld, --logdir LOGDIR", "path under which to save YAML logs (no logging if unset)" }); + options.push_back({ "logging", " --log-test", "Run simple logging test" }); + options.push_back({ "logging", " --log-disable", "Disable trace logs" }); + options.push_back({ "logging", " --log-enable", "Enable trace logs" }); + options.push_back({ "logging", " --log-file FNAME", "Specify a log filename (without extension)" }); + options.push_back({ "logging", " --log-new", "Create a separate new log file on start. " + "Each log file will have unique name: \"..log\"" }); + options.push_back({ "logging", " --log-append", "Don't truncate the old log file." }); #endif // LOG_DISABLE_LOGS + + options.push_back({ "cvector" }); + options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); + options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); + options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); + options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); + options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); + options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" }); + + printf("usage: %s [options]\n", argv[0]); + + for (const auto & o : options) { + if (!o.grp.empty()) { + printf("\n%s:\n\n", o.grp.c_str()); + continue; + } + printf(" %-32s", o.args.c_str()); + if (o.args.length() > 30) { + printf("\n%34s", ""); + } + + const auto desc = o.desc; + size_t start = 0; + size_t end = desc.find('\n'); + while (end != std::string::npos) { + printf("%s\n%34s", desc.substr(start, end - start).c_str(), ""); + start = end + 1; + end = desc.find('\n', start); + } + + printf("%s\n", desc.substr(start).c_str()); + } + printf("\n"); } std::string gpt_params_get_system_info(const gpt_params & params) { @@ -1616,24 +1765,6 @@ std::string string_get_sortable_timestamp() { return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); } -std::string string_random_prompt(std::mt19937 & rng) { - const int r = rng() % 10; - switch (r) { - case 0: return "So"; - case 1: return "Once upon a time"; - case 2: return "When"; - case 3: return "The"; - case 4: return "After"; - case 5: return "If"; - case 6: return "import"; - case 7: return "He"; - case 8: return "She"; - case 9: return "They"; - } - - LM_GGML_UNREACHABLE(); -} - void string_process_escapes(std::string & input) { std::size_t input_len = input.length(); std::size_t output_idx = 0; @@ -1893,6 +2024,16 @@ std::string fs_get_cache_directory() { return ensure_trailing_slash(cache_directory); } +std::string fs_get_cache_file(const std::string & filename) { + LM_GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + return cache_directory + filename; +} + // // Model utils @@ -1904,9 +2045,9 @@ std::tuple llama_init_from_gpt_par llama_model * model = nullptr; if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); + model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else { model = llama_load_model_from_file(params.model.c_str(), mparams); } @@ -1952,19 +2093,14 @@ std::tuple llama_init_from_gpt_par for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? NULL - : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { + auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str()); + if (adapter == nullptr) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); llama_free(lctx); llama_free_model(model); return std::make_tuple(nullptr, nullptr); } + llama_lora_adapter_set(lctx, adapter, lora_scale); } if (params.ignore_eos) { @@ -1974,7 +2110,24 @@ std::tuple llama_init_from_gpt_par if (params.warmup) { LOG("warming up the model with an empty run\n"); - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + std::vector tmp; + llama_token bos = llama_token_bos(model); + llama_token eos = llama_token_eos(model); + // some models (e.g. T5) don't have a BOS token + if (bos != -1) { + tmp.push_back(bos); + } + tmp.push_back(eos); + + if (llama_model_has_encoder(model)) { + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = bos; + } + tmp.clear(); + tmp.push_back(decoder_start_token_id); + } llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); @@ -2057,6 +2210,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.pooling_type = params.pooling_type; + cparams.attention_type = params.attention_type; cparams.defrag_thold = params.defrag_thold; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; @@ -2076,7 +2230,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) { return str.rfind(prefix, 0) == 0; } -static bool llama_download_file(const std::string & url, const std::string & path) { +static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); @@ -2091,6 +2245,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + // Check if hf-token or bearer-token was specified + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer "; + auth_header += hf_token.c_str(); + struct curl_slist *http_headers = NULL; + http_headers = curl_slist_append(http_headers, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + } + #if defined(_WIN32) // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // operating system. Currently implemented under MS-Windows. @@ -2207,7 +2370,14 @@ static bool llama_download_file(const std::string & url, const std::string & pat } // Set the output file - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb"), fclose); + + struct FILE_deleter { + void operator()(FILE * f) const { + fclose(f); + } + }; + + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); if (!outfile) { fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; @@ -2279,6 +2449,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat struct llama_model * llama_load_model_from_url( const char * model_url, const char * path_model, + const char * hf_token, const struct llama_model_params & params) { // Basic validation of the model_url if (!model_url || strlen(model_url) == 0) { @@ -2286,7 +2457,7 @@ struct llama_model * llama_load_model_from_url( return NULL; } - if (!llama_download_file(model_url, path_model)) { + if (!llama_download_file(model_url, path_model, hf_token)) { return NULL; } @@ -2334,14 +2505,14 @@ struct llama_model * llama_load_model_from_url( // Prepare download in parallel std::vector> futures_download; for (int idx = 1; idx < n_split; idx++) { - futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool { + futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool { char split_path[PATH_MAX] = {0}; llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - return llama_download_file(split_url, split_path); + return llama_download_file(split_url, split_path, hf_token); }, idx)); } @@ -2360,6 +2531,7 @@ struct llama_model * llama_load_model_from_hf( const char * repo, const char * model, const char * path_model, + const char * hf_token, const struct llama_model_params & params) { // construct hugging face model url: // @@ -2375,7 +2547,7 @@ struct llama_model * llama_load_model_from_hf( model_url += "/resolve/main/"; model_url += model; - return llama_load_model_from_url(model_url.c_str(), path_model, params); + return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params); } #else @@ -2383,6 +2555,7 @@ struct llama_model * llama_load_model_from_hf( struct llama_model * llama_load_model_from_url( const char * /*model_url*/, const char * /*path_model*/, + const char * /*hf_token*/, const struct llama_model_params & /*params*/) { fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); return nullptr; @@ -2392,6 +2565,7 @@ struct llama_model * llama_load_model_from_hf( const char * /*repo*/, const char * /*model*/, const char * /*path_model*/, + const char * /*hf_token*/, const struct llama_model_params & /*params*/) { fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); return nullptr; @@ -2456,57 +2630,126 @@ std::vector llama_tokenize( } std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); - LM_GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + LM_GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); } - return std::string(result.data(), result.size()); + return piece; } -std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) { - const llama_token bos_id = llama_token_bos(llama_get_model(ctx)); - - std::string piece; - std::string result; +std::string llama_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + LM_GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization + } - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); + text.resize(n_chars); - // remove the leading space of the first non-BOS token - if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') { - piece = piece.substr(1); - } + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return text; +} - result += piece; - } +bool llama_should_add_bos_token(const llama_model * model) { + const int add_bos = llama_add_bos_token(model); - return result; + return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } -std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) { - std::string piece; - std::string result; +// +// Chat template utils +// + +bool llama_chat_verify_template(const std::string & tmpl) { + llama_chat_message chat[] = {{"user", "test"}}; + int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; +} - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & msgs, + bool add_ass) { + int alloc_size = 0; + bool fallback = false; // indicate if we must fallback to default chatml + std::vector chat; + for (auto & msg : msgs) { + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + } + + const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); + std::vector buf(alloc_size); + + // run the first time to get the total output length + int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + if (ptr_tmpl != nullptr) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } else { + // If the built-in template is not supported, we default to chatml + res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + fallback = true; + } + } - result += piece; + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template( + fallback ? nullptr : model, + fallback ? "chatml" : ptr_tmpl, + chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } - // NOTE: the original tokenizer decodes bytes after collecting the pieces. - return result; + std::string formatted_chat(buf.data(), res); + return formatted_chat; } -bool llama_should_add_bos_token(const llama_model * model) { - const int add_bos = llama_add_bos_token(model); +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass) { + std::ostringstream ss; + auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false); + std::vector chat_new(past_msg); + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg + chat_new.push_back(new_msg); + auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); +} - return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl) { + std::vector msgs = { + {"system", "You are a helpful assistant"}, + {"user", "Hello"}, + {"assistant", "Hi there"}, + {"user", "How are you?"}, + }; + return llama_chat_apply_template(model, tmpl, msgs, true); } // @@ -2588,14 +2831,34 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n) { +void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) { double sum = 0.0; - for (int i = 0; i < n; i++) { - sum += inp[i] * inp[i]; + + switch (embd_norm) { + case -1: // no normalisation + sum = 1.0; + break; + case 0: // max absolute + for (int i = 0; i < n; i++) { + if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + } + sum /= 32760.0; // make an int16 range + break; + case 2: // euclidean + for (int i = 0; i < n; i++) { + sum += inp[i] * inp[i]; + } + sum = std::sqrt(sum); + break; + default: // p-norm (euclidean is p-norm p=2) + for (int i = 0; i < n; i++) { + sum += std::pow(std::abs(inp[i]), embd_norm); + } + sum = std::pow(sum, 1.0 / embd_norm); + break; } - sum = sqrt(sum); - const float norm = sum > 0.0 ? 1.0f / sum : 0.0f; + const float norm = sum > 0.0 ? 1.0 / sum : 0.0f; for (int i = 0; i < n; i++) { out[i] = inp[i] * norm; @@ -2613,6 +2876,14 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) sum2 += embd2[i] * embd2[i]; } + // Handle the case where one or both vectors are zero vectors + if (sum1 == 0.0 || sum2 == 0.0) { + if (sum1 == 0.0 && sum2 == 0.0) { + return 1.0f; // two zero vectors are similar + } + return 0.0f; + } + return sum / (sqrt(sum1) * sqrt(sum2)); } @@ -2621,125 +2892,87 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) // static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) { - int32_t n_tensors; - - size_t n_bytes = 0; - - uint32_t max_direction_layer = 0; - llama_control_vector_data result = { -1, {} }; - // calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer - { - struct lm_ggml_init_params meta_params = { - /* .mem_size = */ lm_ggml_tensor_overhead() * 128 + lm_ggml_graph_overhead(), - /* .mem_buffer = */ nullptr, - /* .no_alloc = */ true, - }; - lm_ggml_context * meta_ctx = lm_ggml_init(meta_params); - struct lm_gguf_init_params meta_lm_gguf_params = { - /* .no_alloc = */ true, - /* .ctx = */ &meta_ctx, - }; - struct lm_gguf_context * meta_ctx_gguf = lm_gguf_init_from_file(load_info.fname.c_str(), meta_lm_gguf_params); - if (!meta_ctx_gguf) { - fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str()); - lm_ggml_free(meta_ctx); - return result; - } - - n_tensors = lm_gguf_get_n_tensors(meta_ctx_gguf); - for (int i = 0; i < n_tensors; i++) { - std::string name = lm_gguf_get_tensor_name(meta_ctx_gguf, i); - - // split on '.' - size_t dotpos = name.find('.'); - if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { - try { - uint32_t layer = std::stoi(name.substr(dotpos + 1)); - if (layer == 0) { - fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); - lm_ggml_free(meta_ctx); - lm_gguf_free(meta_ctx_gguf); - return result; - } - if (layer > max_direction_layer) { - max_direction_layer = layer; - } - } catch (...) { - fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); - lm_ggml_free(meta_ctx); - lm_gguf_free(meta_ctx_gguf); - return result; - } - } - - struct lm_ggml_tensor * tensor_meta = lm_ggml_get_tensor(meta_ctx, name.c_str()); - if (tensor_meta->type != LM_GGML_TYPE_F32 || lm_ggml_n_dims(tensor_meta) != 1) { - fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); - lm_ggml_free(meta_ctx); - lm_gguf_free(meta_ctx_gguf); - return result; - } - if (result.n_embd == -1) { - result.n_embd = lm_ggml_nelements(tensor_meta); - } else if (lm_ggml_nelements(tensor_meta) != result.n_embd) { - fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str()); - lm_ggml_free(meta_ctx); - lm_gguf_free(meta_ctx_gguf); - return result; - } - n_bytes += lm_ggml_nbytes(tensor_meta); - } - lm_ggml_free(meta_ctx); - lm_gguf_free(meta_ctx_gguf); + lm_ggml_context * ctx = nullptr; + struct lm_gguf_init_params meta_lm_gguf_params = { + /* .no_alloc = */ false, + /* .ctx = */ &ctx, + }; + struct lm_gguf_context * ctx_gguf = lm_gguf_init_from_file(load_info.fname.c_str(), meta_lm_gguf_params); + if (!ctx_gguf) { + fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); + return result; } + int32_t n_tensors = lm_gguf_get_n_tensors(ctx_gguf); if (n_tensors == 0) { fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); - return result; } - // load and scale tensors into final control vector context - struct lm_ggml_init_params lm_ggml_params = { - /* .mem_size = */ lm_ggml_tensor_overhead() * n_tensors + n_bytes, - /* .mem_buffer = */ nullptr, - /* .no_alloc = */ false, - }; - struct lm_ggml_context * ctx = lm_ggml_init(lm_ggml_params); + for (int i = 0; i < n_tensors; i++) { + std::string name = lm_gguf_get_tensor_name(ctx_gguf, i); - struct lm_gguf_init_params params = { - /*.no_alloc = */ false, - /*.ctx = */ &ctx, - }; - struct lm_gguf_context * ctx_gguf = lm_gguf_init_from_file(load_info.fname.c_str(), params); - if (!ctx_gguf) { - fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str()); - lm_ggml_free(ctx); - return result; - } + int layer_idx = -1; + + // split on '.' + size_t dotpos = name.find('.'); + if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { + try { + layer_idx = std::stoi(name.substr(dotpos + 1)); + } catch (...) { + layer_idx = -1; + } + } + if (layer_idx < 0) { + fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } else if (layer_idx == 0) { + fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } - // do not store data for layer 0 (it's not used) - result.data.resize(result.n_embd * max_direction_layer); + struct lm_ggml_tensor * tensor = lm_ggml_get_tensor(ctx, name.c_str()); + if (tensor->type != LM_GGML_TYPE_F32) { + fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + if (lm_ggml_n_dims(tensor) != 1) { + fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } - for (uint32_t il = 1; il <= max_direction_layer; il++) { - const std::string name = "direction." + std::to_string(il); - const lm_ggml_tensor * tensor = lm_ggml_get_tensor(ctx, name.c_str()); + if (result.n_embd == -1) { + result.n_embd = lm_ggml_nelements(tensor); + } else if (lm_ggml_nelements(tensor) != result.n_embd) { + fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } - float * dst = result.data.data() + result.n_embd * (il - 1); + // extend if necessary - do not store data for layer 0 (it's not used) + result.data.resize(std::max(result.data.size(), static_cast(result.n_embd * layer_idx)), 0.0f); - if (tensor) { - const float * src = (const float *) tensor->data; - for (int j = 0; j < result.n_embd; j++) { - dst[j] = src[j] * load_info.strength; - } - } else { - for (int j = 0; j < result.n_embd; j++) { - dst[j] = 0.0f; - } + const float * src = (const float *) tensor->data; + float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0] + for (int j = 0; j < result.n_embd; j++) { + dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file } + } + if (result.n_embd == -1) { + fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); + result.data.clear(); + } + + lm_gguf_free(ctx_gguf); + lm_ggml_free(ctx); + return result; } @@ -2750,16 +2983,19 @@ llama_control_vector_data llama_control_vector_load(const std::vector=32 to use BLAS) - int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. - int32_t grp_attn_n = 1; // group-attention factor - int32_t grp_attn_w = 512; // group-attention width - int32_t n_print = -1; // print token count every n tokens (-1 = disabled) - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + int32_t n_threads_draft = -1; + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch_draft = -1; + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 0; // context size + int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim - int32_t yarn_orig_ctx = 0; // YaRN original context length + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold - std::string rpc_servers = ""; // comma separated list of RPC servers lm_ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; lm_ggml_numa_strategy numa = LM_GGML_NUMA_STRATEGY_DISABLED; + enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings // // sampling parameters struct llama_sampling_params sparams; - std::string model = ""; // model path - std::string model_draft = ""; // draft model for speculative decoding + std::string model = ""; // model path + std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias - std::string model_url = ""; // model url to download - std::string hf_repo = ""; // HF repo - std::string hf_file = ""; // HF file + std::string model_url = ""; // model url to download + std::string hf_token = ""; // HF token + std::string hf_repo = ""; // HF repo + std::string hf_file = ""; // HF file std::string prompt = ""; - std::string prompt_file = ""; // store the external prompt file name - std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::vector antiprompt; // string upon seeing which more user input is prompted - std::string logdir = ""; // directory in which to save YAML log files + std::string prompt_file = ""; // store the external prompt file name + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state + std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with + std::string logdir = ""; // directory in which to save YAML log files std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding - std::string logits_file = ""; // file for saving *all* logits + std::string logits_file = ""; // file for saving *all* logits + std::string rpc_servers = ""; // comma separated list of RPC servers + std::vector in_files; // all input files + std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; // TODO: avoid tuple, use struct @@ -135,37 +143,35 @@ struct gpt_params { std::vector control_vectors; // control vector with user defined scale + int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector - int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. - int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line - // (which is more convenient to use for plotting) - // - bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt - size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt - size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt + size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed - bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt - size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed + bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt + size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed - bool kl_divergence = false; // compute KL divergence + bool kl_divergence = false; // compute KL divergence - bool random_prompt = false; // do not randomize prompt if none provided + bool usage = false; // print usage bool use_color = false; // use color to distinguish generations and inputs - bool interactive = false; // interactive mode - bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode bool special = false; // enable special token output + bool interactive = false; // interactive mode + bool interactive_first = false; // wait for user input immediately bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) - bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it - bool embedding = false; // get only sentence embedding - bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" - bool interactive_first = false; // wait for user input immediately + bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\" bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly @@ -173,7 +179,6 @@ struct gpt_params { bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens - bool instruct = false; // instruction mode (used for Alpaca models) bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory @@ -191,8 +196,79 @@ struct gpt_params { // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector std::vector image; // path to image file(s) + + // embedding + bool embedding = false; // get only sentence embedding + int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix + std::string embd_sep = "\n"; // separator of embendings + + // server params + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests + + std::string hostname = "127.0.0.1"; + std::string public_path = ""; + std::string chat_template = ""; + std::string system_prompt = ""; + bool enable_chat_template = true; + + std::vector api_keys; + + std::string ssl_file_key = ""; + std::string ssl_file_cert = ""; + + bool endpoint_slots = true; + bool endpoint_metrics = false; + + bool log_json = false; + + std::string slot_save_path; + + float slot_prompt_similarity = 0.5f; + + // batched-bench params + bool is_pp_shared = false; + + std::vector n_pp; + std::vector n_tg; + std::vector n_pl; + + // retrieval params + std::vector context_files; // context files to embed + + int32_t chunk_size = 64; // chunk size for context embedding + + std::string chunk_separator = "\n"; // chunk separator for context embedding + + // passkey params + int32_t n_junk = 250; // number of times to repeat the junk text + int32_t i_pos = -1; // position of the passkey in the junk text + + // imatrix params + std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file + + int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations + int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations + int32_t i_chunk = 0; // start processing from this chunk + + bool process_output = false; // collect data for the output tensor + bool compute_ppl = true; // whether to compute perplexity + + // cvector-generator params + int n_pca_batch = 100; + int n_pca_iterations = 1000; + dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; + std::string cvector_outfile = "control_vector.gguf"; + std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; + std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; + + bool spm_infill = false; // suffix/prefix/middle pattern for infill }; +void gpt_params_handle_hf_token(gpt_params & params); void gpt_params_handle_model_default(gpt_params & params); bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); @@ -210,7 +286,20 @@ std::vector string_split(std::string input, char separator); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); -std::string string_random_prompt(std::mt19937 & rng); + +template +static std::vector string_split(const std::string & str, char delim) { + std::vector values; + std::istringstream str_stream(str); + std::string token; + while (std::getline(str_stream, token, delim)) { + T value; + std::istringstream token_stream(token); + token_stream >> value; + values.push_back(value); + } + return values; +} bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -223,6 +312,7 @@ bool fs_validate_filename(const std::string & filename); bool fs_create_directory_with_parents(const std::string & path); std::string fs_get_cache_directory(); +std::string fs_get_cache_file(const std::string & filename); // // Model utils @@ -234,8 +324,8 @@ std::tuple llama_init_from_gpt_par struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); -struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params); -struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params); +struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); +struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); // Batch utils @@ -273,26 +363,50 @@ std::string llama_token_to_piece( llama_token token, bool special = true); -// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function -// that takes into account the tokenizer type and decides how to handle the leading space -// -// detokenizes a vector of tokens into a string -// should work similar to Python's `tokenizer.decode` -// removes the leading space from the first non-BOS token -std::string llama_detokenize_spm( - llama_context * ctx, - const std::vector & tokens); - // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` -std::string llama_detokenize_bpe( +// optionally renders special/control tokens +std::string llama_detokenize( llama_context * ctx, - const std::vector & tokens); + const std::vector & tokens, + bool special = true); // Uses the value from the model metadata if possible, otherwise // defaults to true when model type is SPM, otherwise false. bool llama_should_add_bos_token(const llama_model * model); +// +// Chat template utils +// + +// same with llama_chat_message, but uses std::string +struct llama_chat_msg { + std::string role; + std::string content; +}; + +// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid +bool llama_chat_verify_template(const std::string & tmpl); + +// CPP wrapper for llama_chat_apply_template +// If the built-in template is not supported, we default to chatml +// If the custom "tmpl" is not supported, we throw an error +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & chat, + bool add_ass); + +// Format single message, while taking into account the position of that message in chat history +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass); + +// Returns an example of formatted chat +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl); + // // KV cache utils // @@ -307,7 +421,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n); +void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); @@ -351,4 +465,3 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha void yaml_dump_non_result_info( FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); - diff --git a/cpp/ggml-aarch64.c b/cpp/ggml-aarch64.c new file mode 100644 index 00000000..12668bcf --- /dev/null +++ b/cpp/ggml-aarch64.c @@ -0,0 +1,2193 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +#define LM_GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for LM_GGML_ASSERT + +#include "ggml-aarch64.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#endif + +#define UNUSED LM_GGML_UNUSED + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave +// +// - in : an array of block_q4_0 pointers +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations durin unpacking) +// +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 2; i++) { + int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 4; i++) { + int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { + assert(nrow == 4); + UNUSED(nrow); + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } +} + +static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { + assert(n_per_row % QK4_0 == 0); + const int nb = n_per_row / QK4_0; + + void * out_ptr = NULL; + if (nrows_interleaved == 8) { + out_ptr = (block_q4_0x8 *) dst; + } + else if (nrows_interleaved == 4) { + out_ptr = (block_q4_0x4 *) dst; + } + assert(nrows_interleaved <= 8); + block_q4_0 dst_tmp[8]; + + for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { + + for (int64_t x = 0; x < nb; x++) { + + for (int i = 0; i < nrows_interleaved; i++ ) { + quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); + } + + if (nrows_interleaved == 8) { + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); + out_ptr = (block_q4_0x8 *) out_ptr + 1; + } + else if (nrows_interleaved == 4) { + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); + out_ptr = (block_q4_0x4 *) out_ptr + 1; + } + } + } + + return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); +} + +size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); + } + else { + assert(false); + return 0; + } +} + +void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + LM_GGML_ASSERT(!(lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v31.16b, #0x4\n" + "movi v30.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "movi v29.16b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ldr q28, [%x[b_ptr], #0x0]\n" + "ldr q27, [x22, #0x0]\n" + "movi v26.4s, #0x0\n" + "sub x20, x22, #0x2\n" + "ldr q25, [x22, #0x10]\n" + "ldr q24, [%x[b_ptr], #0x10]\n" + "sub x21, x21, #0x1\n" + "add x22, x22, #0x22\n" + "ldr q23, [%x[b_ptr], #0x20]\n" + "ldr q22, [%x[b_ptr], #0x30]\n" + "ld1r { v21.8h }, [x20]\n" + "ldr q20, [%x[b_ptr], #-0x8]\n" + "sshl v16.16b, v28.16b, v31.16b\n" + "and v28.16b, v28.16b, v30.16b\n" + "sshl v19.16b, v24.16b, v31.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "sshl v18.16b, v23.16b, v31.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" + "sshl v17.16b, v22.16b, v31.16b\n" + "and v22.16b, v22.16b, v30.16b\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v16.4s, v20.4h\n" + ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" + "fmul v16.4s, v16.4s, v21.4s\n" + ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" + ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" + ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" + ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" + ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" + ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v29.4s, v26.4s, v16.4s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q29, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" + ); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v2.16b, #0x4\n" + "movi v1.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x23, %x[a_ptr], #0x2\n" + "movi v0.16b, #0x0\n" + "mov x22, %x[nb]\n" + "2:" // Block loop + "ldr q31, [%x[b_ptr], #0x0]\n" + "ldr q30, [%x[b_ptr], #0x10]\n" + "mov x21, x23\n" + "movi v29.4s, #0x0\n" + "ldr q28, [%x[b_ptr], #0x20]\n" + "ldr q27, [%x[b_ptr], #0x30]\n" + "movi v26.4s, #0x0\n" + "sub x20, x23, #0x2\n" + "ld1r { v25.8h }, [x20]\n" + "ldr q24, [%x[b_ptr], #-0x8]\n" + "sub x22, x22, #0x1\n" + "add x23, x23, #0x22\n" + "ld1r { v23.2d }, [x21], #0x8\n" + "sshl v22.16b, v31.16b, v2.16b\n" + "sshl v16.16b, v30.16b, v2.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "sshl v20.16b, v28.16b, v2.16b\n" + "sshl v19.16b, v27.16b, v2.16b\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "ld1r { v17.2d }, [x21], #0x8\n" + "and v31.16b, v31.16b, v1.16b\n" + "and v30.16b, v30.16b, v1.16b\n" + ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" + ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" + "and v28.16b, v28.16b, v1.16b\n" + "and v27.16b, v27.16b, v1.16b\n" + "fcvtl v25.4s, v25.4h\n" + "fcvtl v16.4s, v24.4h\n" + ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" + ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" + "fmul v16.4s, v16.4s, v25.4s\n" + ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" + ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" + ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" + ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" + "addp v29.4s, v29.4s, v26.4s\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v0.4s, v29.4s, v16.4s\n" + "cbnz x22, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q0, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + LM_GGML_ASSERT((lm_ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (lm_ggml_cpu_has_neon()) { + LM_GGML_ASSERT(((lm_ggml_cpu_has_sve() && (svcntw() == 8)) || lm_ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + LM_GGML_ASSERT(lm_ggml_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + LM_GGML_ASSERT(!(lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + LM_GGML_ASSERT(!(lm_ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, #0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + LM_GGML_ASSERT((lm_ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (lm_ggml_cpu_has_neon()) { + LM_GGML_ASSERT(((lm_ggml_cpu_has_sve() && (svcntw() == 8)) || lm_ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + LM_GGML_ASSERT(lm_ggml_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} diff --git a/cpp/ggml-aarch64.h b/cpp/ggml-aarch64.h new file mode 100644 index 00000000..90a23d77 --- /dev/null +++ b/cpp/ggml-aarch64.h @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +#pragma once + +#define LM_GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "ggml.h" + +// GGML internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// Quantization +void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); + +void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); + +// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") +size_t quantize_q4_0_4x4(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_4x8(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_8x8(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +// GEMV +void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); + +// GEMM +void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); + +#ifdef __cplusplus +} +#endif + diff --git a/cpp/ggml-alloc.c b/cpp/ggml-alloc.c index 57f5f88a..e2195f04 100644 --- a/cpp/ggml-alloc.c +++ b/cpp/ggml-alloc.c @@ -339,6 +339,7 @@ struct hash_node { }; struct tensor_alloc { + int buffer_id; size_t offset; size_t size_max; // 0 = pre-allocated, unused, or view }; @@ -349,7 +350,6 @@ struct leaf_alloc { }; struct node_alloc { - int buffer_id; struct tensor_alloc dst; struct tensor_alloc src[LM_GGML_MAX_SRC]; }; @@ -377,7 +377,7 @@ lm_ggml_gallocr_t lm_ggml_gallocr_new_n(lm_ggml_backend_buffer_type_t * bufts, i galloc->bufts = calloc(n_bufs, sizeof(lm_ggml_backend_buffer_type_t)); LM_GGML_ASSERT(galloc->bufts != NULL); - galloc->buffers = calloc(n_bufs, sizeof(lm_ggml_backend_buffer_t) * n_bufs); + galloc->buffers = calloc(n_bufs, sizeof(lm_ggml_backend_buffer_t)); LM_GGML_ASSERT(galloc->buffers != NULL); galloc->buf_tallocs = calloc(n_bufs, sizeof(struct lm_ggml_dyn_tallocr *)); @@ -386,8 +386,19 @@ lm_ggml_gallocr_t lm_ggml_gallocr_new_n(lm_ggml_backend_buffer_type_t * bufts, i for (int i = 0; i < n_bufs; i++) { galloc->bufts[i] = bufts[i]; galloc->buffers[i] = NULL; - size_t alignment = lm_ggml_backend_buft_get_alignment(bufts[i]); - galloc->buf_tallocs[i] = lm_ggml_dyn_tallocr_new(alignment); + + // check if the same buffer type is used multiple times and reuse the same allocator + for (int j = 0; j < i; j++) { + if (bufts[i] == bufts[j]) { + galloc->buf_tallocs[i] = galloc->buf_tallocs[j]; + break; + } + } + + if (galloc->buf_tallocs[i] == NULL) { + size_t alignment = lm_ggml_backend_buft_get_alignment(bufts[i]); + galloc->buf_tallocs[i] = lm_ggml_dyn_tallocr_new(alignment); + } } galloc->n_buffers = n_bufs; @@ -405,10 +416,30 @@ void lm_ggml_gallocr_free(lm_ggml_gallocr_t galloc) { for (int i = 0; i < galloc->n_buffers; i++) { if (galloc->buffers != NULL) { - lm_ggml_backend_buffer_free(galloc->buffers[i]); + // skip if already freed + bool freed = false; + for (int j = 0; j < i; j++) { + if (galloc->buffers[j] == galloc->buffers[i]) { + freed = true; + break; + } + } + if (!freed) { + lm_ggml_backend_buffer_free(galloc->buffers[i]); + } } if (galloc->buf_tallocs != NULL) { - lm_ggml_dyn_tallocr_free(galloc->buf_tallocs[i]); + // skip if already freed + bool freed = false; + for (int j = 0; j < i; j++) { + if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) { + freed = true; + break; + } + } + if (!freed) { + lm_ggml_dyn_tallocr_free(galloc->buf_tallocs[i]); + } } } @@ -511,17 +542,18 @@ static void lm_ggml_gallocr_allocate_node(lm_ggml_gallocr_t galloc, struct lm_gg } } -static void lm_ggml_gallocr_free_node(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, int buffer_id) { +static void lm_ggml_gallocr_free_node(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node) { // graph outputs are never freed if (node->flags & LM_GGML_TENSOR_FLAG_OUTPUT) { AT_PRINTF("not freeing output %s\n", node->name); return; } - struct lm_ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; - lm_ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, node); size_t offset = hn->offset; + int buffer_id = hn->buffer_id; + struct lm_ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; + lm_ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; size_t size = lm_ggml_backend_buft_get_alloc_size(buft, node); lm_ggml_dyn_tallocr_free_tensor(alloc, offset, size, node); hn->allocated = false; @@ -626,11 +658,11 @@ static void lm_ggml_gallocr_alloc_graph_impl(lm_ggml_gallocr_t galloc, struct lm AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) { - lm_ggml_gallocr_free_node(galloc, view_src, buffer_id); + lm_ggml_gallocr_free_node(galloc, view_src); } } else if (p_hn->allocated) { - lm_ggml_gallocr_free_node(galloc, parent, buffer_id); + lm_ggml_gallocr_free_node(galloc, parent); } } AT_PRINTF("\n"); @@ -674,22 +706,25 @@ bool lm_ggml_gallocr_reserve_n(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * for (int i = 0; i < graph->n_nodes; i++) { struct lm_ggml_tensor * node = graph->nodes[i]; struct node_alloc * node_alloc = &galloc->node_allocs[i]; - node_alloc->buffer_id = get_node_buffer_id(node_buffer_ids, i); if (node->view_src || node->data) { + node_alloc->dst.buffer_id = -1; node_alloc->dst.offset = SIZE_MAX; node_alloc->dst.size_max = 0; } else { struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, node); - node_alloc->dst.offset = hn->offset; - node_alloc->dst.size_max = lm_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); + node_alloc->dst.buffer_id = hn->buffer_id; + node_alloc->dst.offset = hn->offset; + node_alloc->dst.size_max = lm_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); } for (int j = 0; j < LM_GGML_MAX_SRC; j++) { struct lm_ggml_tensor * src = node->src[j]; if (!src || src->view_src || src->data) { + node_alloc->src[j].buffer_id = -1; node_alloc->src[j].offset = SIZE_MAX; node_alloc->src[j].size_max = 0; } else { struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, src); + node_alloc->src[j].buffer_id = hn->buffer_id; node_alloc->src[j].offset = hn->offset; node_alloc->src[j].size_max = lm_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src); } @@ -706,9 +741,11 @@ bool lm_ggml_gallocr_reserve_n(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, leaf); galloc->leaf_allocs[i].buffer_id = hn->buffer_id; if (leaf->view_src || leaf->data) { + galloc->leaf_allocs[i].leaf.buffer_id = -1; galloc->leaf_allocs[i].leaf.offset = SIZE_MAX; galloc->leaf_allocs[i].leaf.size_max = 0; } else { + galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id; galloc->leaf_allocs[i].leaf.offset = hn->offset; galloc->leaf_allocs[i].leaf.size_max = lm_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf); } @@ -716,6 +753,14 @@ bool lm_ggml_gallocr_reserve_n(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { + // if the buffer type is used multiple times, we reuse the same buffer + for (int j = 0; j < i; j++) { + if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) { + galloc->buffers[i] = galloc->buffers[j]; + break; + } + } + size_t cur_size = galloc->buffers[i] ? lm_ggml_backend_buffer_get_size(galloc->buffers[i]) : 0; size_t new_size = lm_ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); @@ -724,12 +769,14 @@ bool lm_ggml_gallocr_reserve_n(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph * #ifndef NDEBUG fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, lm_ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif + lm_ggml_backend_buffer_free(galloc->buffers[i]); galloc->buffers[i] = lm_ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); if (galloc->buffers[i] == NULL) { fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, lm_ggml_backend_buft_name(galloc->bufts[i]), new_size); return false; } + lm_ggml_backend_buffer_set_usage(galloc->buffers[i], LM_GGML_BACKEND_BUFFER_USAGE_COMPUTE); } } @@ -740,7 +787,8 @@ bool lm_ggml_gallocr_reserve(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph *gr return lm_ggml_gallocr_reserve_n(galloc, graph, NULL, NULL); } -static void lm_ggml_gallocr_init_tensor(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * tensor, int buffer_id, struct tensor_alloc * tensor_alloc) { +static void lm_ggml_gallocr_init_tensor(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) { + int buffer_id = tensor_alloc->buffer_id; assert(tensor->data || tensor->view_src || lm_ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max); if (tensor->view_src != NULL) { @@ -750,7 +798,7 @@ static void lm_ggml_gallocr_init_tensor(lm_ggml_gallocr_t galloc, struct lm_ggml // this tensor was allocated without ggml-backend return; } - lm_ggml_backend_view_init(galloc->buffers[buffer_id], tensor); + lm_ggml_backend_view_init(tensor); } } else { if (tensor->data == NULL) { @@ -768,8 +816,8 @@ static void lm_ggml_gallocr_init_tensor(lm_ggml_gallocr_t galloc, struct lm_ggml } } -static bool lm_ggml_gallocr_node_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, struct node_alloc * nalloc, struct tensor_alloc * talloc) { - lm_ggml_backend_buffer_type_t buft = galloc->bufts[nalloc->buffer_id]; +static bool lm_ggml_gallocr_node_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, struct tensor_alloc * talloc) { + lm_ggml_backend_buffer_type_t buft = talloc->buffer_id != -1 ? galloc->bufts[talloc->buffer_id] : NULL; size_t node_size = (node->data || node->view_src) ? 0 : lm_ggml_backend_buft_get_alloc_size(buft, node); return talloc->size_max >= node_size; } @@ -793,7 +841,7 @@ static bool lm_ggml_gallocr_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_gg struct lm_ggml_tensor * node = graph->nodes[i]; struct node_alloc * node_alloc = &galloc->node_allocs[i]; - if (!lm_ggml_gallocr_node_needs_realloc(galloc, node, node_alloc, &node_alloc->dst)) { + if (!lm_ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) { #ifndef NDEBUG fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name); #endif @@ -805,7 +853,7 @@ static bool lm_ggml_gallocr_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_gg if (src == NULL) { continue; } - if (!lm_ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) { + if (!lm_ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) { #ifndef NDEBUG fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); #endif @@ -846,7 +894,7 @@ bool lm_ggml_gallocr_alloc_graph(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph for (int i = 0; i < graph->n_leafs; i++) { struct lm_ggml_tensor * leaf = graph->leafs[i]; struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i]; - lm_ggml_gallocr_init_tensor(galloc, leaf, leaf_alloc->buffer_id, &leaf_alloc->leaf); + lm_ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf); } // nodes for (int i = 0; i < graph->n_nodes; i++) { @@ -857,9 +905,9 @@ bool lm_ggml_gallocr_alloc_graph(lm_ggml_gallocr_t galloc, struct lm_ggml_cgraph if (src == NULL) { continue; } - lm_ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]); + lm_ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]); } - lm_ggml_gallocr_init_tensor(galloc, node, node_alloc->buffer_id, &node_alloc->dst); + lm_ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst); } return true; @@ -871,6 +919,15 @@ size_t lm_ggml_gallocr_get_buffer_size(lm_ggml_gallocr_t galloc, int buffer_id) if (galloc->buffers[buffer_id] == NULL) { return 0; } + + for (int i = 0; i < buffer_id; i++) { + if (galloc->buffers[i] == galloc->buffers[buffer_id]) { + // this buffer is the same as a previous one due to the same buffer type being used multiple times + // only return the buffer size the first time it appears to avoid double counting + return 0; + } + } + return lm_ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); } @@ -886,7 +943,7 @@ static bool alloc_tensor_range(struct lm_ggml_context * ctx, fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, lm_ggml_backend_buft_name(buft), size); #endif for (size_t i = 0; i < *n_buffers; i++) { - lm_ggml_backend_buffer_free(*buffers[i]); + lm_ggml_backend_buffer_free((*buffers)[i]); } free(*buffers); return false; @@ -899,12 +956,12 @@ static bool alloc_tensor_range(struct lm_ggml_context * ctx, if (t->view_src == NULL) { lm_ggml_tallocr_alloc(&tallocr, t); } else if (t->buffer == NULL) { - lm_ggml_backend_view_init(buffer, t); + lm_ggml_backend_view_init(t); } } else { if (t->view_src != NULL && t->buffer == NULL) { // view of a pre-allocated tensor - lm_ggml_backend_view_init(buffer, t); + lm_ggml_backend_view_init(t); } } } diff --git a/cpp/ggml-backend-impl.h b/cpp/ggml-backend-impl.h index 6971cde1..31eba9df 100644 --- a/cpp/ggml-backend-impl.h +++ b/cpp/ggml-backend-impl.h @@ -17,13 +17,15 @@ extern "C" { struct lm_ggml_backend_buffer_type_i { const char * (*LM_GGML_CALL get_name) (lm_ggml_backend_buffer_type_t buft); + // allocate a buffer of this type lm_ggml_backend_buffer_t (*LM_GGML_CALL alloc_buffer) (lm_ggml_backend_buffer_type_t buft, size_t size); - size_t (*LM_GGML_CALL get_alignment) (lm_ggml_backend_buffer_type_t buft); // tensor alignment - size_t (*LM_GGML_CALL get_max_size) (lm_ggml_backend_buffer_type_t buft); // allocation max size - size_t (*LM_GGML_CALL get_alloc_size) (lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor); // data size needed to allocate the tensor, including padding - bool (*LM_GGML_CALL supports_backend)(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_t backend); // check if the buffer type is usable by the backend + // tensor alignment + size_t (*LM_GGML_CALL get_alignment) (lm_ggml_backend_buffer_type_t buft); + // max buffer size that can be allocated + size_t (*LM_GGML_CALL get_max_size) (lm_ggml_backend_buffer_type_t buft); + // data size needed to allocate the tensor, including padding + size_t (*LM_GGML_CALL get_alloc_size) (lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor); // check if tensor data is in host memory - // should be equivalent to supports_backend(buft, lm_ggml_backend_cpu_init()) bool (*LM_GGML_CALL is_host) (lm_ggml_backend_buffer_type_t buft); }; @@ -92,27 +94,37 @@ extern "C" { void (*LM_GGML_CALL synchronize)(lm_ggml_backend_t backend); // compute graph with a plan (not used currently) + // create a new plan for a graph lm_ggml_backend_graph_plan_t (*LM_GGML_CALL graph_plan_create) (lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph); void (*LM_GGML_CALL graph_plan_free) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); + // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology + void (*LM_GGML_CALL graph_plan_update) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan, const struct lm_ggml_cgraph * cgraph); + // compute the graph with the plan + enum lm_ggml_status (*LM_GGML_CALL graph_plan_compute)(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); - // compute graph with a plan - enum lm_ggml_status (*LM_GGML_CALL graph_plan_compute)(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan); // compute graph without a plan (async) enum lm_ggml_status (*LM_GGML_CALL graph_compute) (lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); - // check if the backend supports an operation + // check if the backend can compute an operation bool (*LM_GGML_CALL supports_op)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); + // check if the backend can use tensors allocated in a buffer type + bool (*LM_GGML_CALL supports_buft)(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft); + // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer // these should be expensive operations with large batch sizes that may benefit from running on this backend // even if the weight has to be copied from the CPU temporarily bool (*LM_GGML_CALL offload_op)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); // (optional) event synchronization + // create a new event that can record events on this backend instance lm_ggml_backend_event_t (*LM_GGML_CALL event_new) (lm_ggml_backend_t backend); void (*LM_GGML_CALL event_free) (lm_ggml_backend_event_t event); + // record an event on the backend instance that created it void (*LM_GGML_CALL event_record) (lm_ggml_backend_event_t event); + // wait for an event on on a different backend instance void (*LM_GGML_CALL event_wait) (lm_ggml_backend_t backend, lm_ggml_backend_event_t event); + // block until an event is recorded void (*LM_GGML_CALL event_synchronize) (lm_ggml_backend_event_t event); }; diff --git a/cpp/ggml-backend.c b/cpp/ggml-backend.c index 1e0de870..a8fbd1a1 100644 --- a/cpp/ggml-backend.c +++ b/cpp/ggml-backend.c @@ -44,10 +44,6 @@ LM_GGML_CALL size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_t return lm_ggml_nbytes(tensor); } -bool lm_ggml_backend_buft_supports_backend(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_t backend) { - return buft->iface.supports_backend(buft, backend); -} - bool lm_ggml_backend_buft_is_host(lm_ggml_backend_buffer_type_t buft) { if (buft->iface.is_host) { return buft->iface.is_host(buft); @@ -138,6 +134,10 @@ void lm_ggml_backend_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_g } } +enum lm_ggml_backend_buffer_usage lm_ggml_backend_buffer_get_usage(lm_ggml_backend_buffer_t buffer) { + return buffer->usage; +} + lm_ggml_backend_buffer_type_t lm_ggml_backend_buffer_get_type(lm_ggml_backend_buffer_t buffer) { return buffer->buft; } @@ -151,7 +151,7 @@ void lm_ggml_backend_buffer_reset(lm_ggml_backend_buffer_t buffer) { bool lm_ggml_backend_buffer_copy_tensor(const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { lm_ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer; if (dst_buf->iface.cpy_tensor) { - return src->buffer->iface.cpy_tensor(dst_buf, src, dst); + return dst_buf->iface.cpy_tensor(dst_buf, src, dst); } return false; } @@ -286,6 +286,10 @@ bool lm_ggml_backend_supports_op(lm_ggml_backend_t backend, const struct lm_ggml return backend->iface.supports_op(backend, op); } +bool lm_ggml_backend_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) { + return backend->iface.supports_buft(backend, buft); +} + bool lm_ggml_backend_offload_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op) { if (backend->iface.offload_op != NULL) { return backend->iface.offload_op(backend, op); @@ -394,7 +398,7 @@ void lm_ggml_backend_event_wait(lm_ggml_backend_t backend, lm_ggml_backend_event // backend registry -#define LM_GGML_REG_MAX_BACKENDS 16 +#define LM_GGML_REG_MAX_BACKENDS 64 struct lm_ggml_backend_reg { char name[128]; @@ -445,6 +449,11 @@ LM_GGML_CALL static void lm_ggml_backend_registry_init(void) { extern LM_GGML_CALL void lm_ggml_backend_kompute_reg_devices(void); lm_ggml_backend_kompute_reg_devices(); #endif + +#ifdef LM_GGML_USE_CANN + extern LM_GGML_CALL int lm_ggml_backend_cann_reg_devices(void); + lm_ggml_backend_cann_reg_devices(); +#endif } LM_GGML_CALL void lm_ggml_backend_register(const char * name, lm_ggml_backend_init_fn init_fn, lm_ggml_backend_buffer_type_t default_buffer_type, void * user_data) { @@ -639,12 +648,6 @@ LM_GGML_CALL static size_t lm_ggml_backend_cpu_buffer_type_get_alignment(lm_ggml LM_GGML_UNUSED(buft); } -LM_GGML_CALL static bool lm_ggml_backend_cpu_buffer_type_supports_backend(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_t backend) { - return lm_ggml_backend_is_cpu(backend); - - LM_GGML_UNUSED(buft); -} - LM_GGML_CALL static bool lm_ggml_backend_cpu_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { return true; @@ -659,7 +662,6 @@ LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .supports_backend = */ lm_ggml_backend_cpu_buffer_type_supports_backend, /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, }, /* .context = */ NULL, @@ -715,7 +717,6 @@ lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void) { /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .supports_backend = */ lm_ggml_backend_cpu_buffer_type_supports_backend, /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, }, /* .context = */ NULL, @@ -836,6 +837,12 @@ LM_GGML_CALL static bool lm_ggml_backend_cpu_supports_op(lm_ggml_backend_t backe LM_GGML_UNUSED(backend); } +LM_GGML_CALL static bool lm_ggml_backend_cpu_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) { + return lm_ggml_backend_buft_is_host(buft); + + LM_GGML_UNUSED(backend); +} + static struct lm_ggml_backend_i cpu_backend_i = { /* .get_name = */ lm_ggml_backend_cpu_name, /* .free = */ lm_ggml_backend_cpu_free, @@ -846,9 +853,11 @@ static struct lm_ggml_backend_i cpu_backend_i = { /* .synchronize = */ NULL, /* .graph_plan_create = */ lm_ggml_backend_cpu_graph_plan_create, /* .graph_plan_free = */ lm_ggml_backend_cpu_graph_plan_free, + /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ lm_ggml_backend_cpu_graph_plan_compute, /* .graph_compute = */ lm_ggml_backend_cpu_graph_compute, /* .supports_op = */ lm_ggml_backend_cpu_supports_op, + /* .supports_buft = */ lm_ggml_backend_cpu_supports_buft, /* .offload_op = */ NULL, /* .event_new = */ NULL, /* .event_free = */ NULL, @@ -1055,6 +1064,9 @@ struct lm_ggml_backend_sched { int * node_backend_ids; // [graph_size] int * leaf_backend_ids; // [graph_size] + int * prev_node_backend_ids; // [graph_size] + int * prev_leaf_backend_ids; // [graph_size] + // copy of the graph with modified inputs struct lm_ggml_cgraph * graph; @@ -1075,6 +1087,8 @@ struct lm_ggml_backend_sched { lm_ggml_backend_sched_eval_callback callback_eval; void * callback_eval_user_data; + bool debug; + // align context_buffer to LM_GGML_MEM_ALIGN #ifdef _MSC_VER __declspec(align(LM_GGML_MEM_ALIGN)) @@ -1097,22 +1111,24 @@ static int lm_ggml_backend_sched_backend_id(lm_ggml_backend_sched_t sched, lm_gg return -1; } -static int lm_ggml_backend_sched_backend_from_buffer(lm_ggml_backend_sched_t sched, const struct lm_ggml_tensor * tensor) { +static int lm_ggml_backend_sched_backend_from_buffer(lm_ggml_backend_sched_t sched, const struct lm_ggml_tensor * tensor, const struct lm_ggml_tensor * op) { lm_ggml_backend_buffer_t buffer = tensor->buffer; if (buffer == NULL) { return -1; } - // find highest prio backend that supports the buffer type + // find highest prio backend that supports the buffer type and the op for (int i = 0; i < sched->n_backends; i++) { - if (lm_ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) { + if (lm_ggml_backend_supports_buft(sched->backends[i], buffer->buft) && + lm_ggml_backend_supports_op(sched->backends[i], op)) { return i; } } - fprintf(stderr, "%s: error: no backend supports buffer type %s used in tensor %s\n", - __func__, lm_ggml_backend_buffer_name(buffer), tensor->name); - LM_GGML_ASSERT(false); +#ifndef NDEBUG + fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", + __func__, lm_ggml_op_desc(tensor), lm_ggml_backend_buffer_name(buffer), tensor->name); +#endif return -1; } @@ -1131,7 +1147,7 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch // TODO: use supports_op to check if the backend supports the op // assign pre-allocated nodes to their backend - int cur_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, tensor); + int cur_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); if (cur_backend_id != -1) { SET_CAUSE(tensor, "1.dst"); return cur_backend_id; @@ -1139,7 +1155,7 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch // view_src if (tensor->view_src != NULL) { - cur_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, tensor->view_src); + cur_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor); if (cur_backend_id != -1) { SET_CAUSE(tensor, "1.vsrc"); return cur_backend_id; @@ -1161,11 +1177,11 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch continue; } if (src->buffer != NULL && src->buffer->usage == LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { - int src_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, src); + int src_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op if (src_backend_id == sched->n_backends - 1) { for (int b = 0; b < src_backend_id; b++) { - if (lm_ggml_backend_offload_op(sched->backends[b], tensor)) { + if (lm_ggml_backend_supports_op(sched->backends[b], tensor) && lm_ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); return b; } @@ -1223,10 +1239,33 @@ static void lm_ggml_backend_sched_print_assignments(lm_ggml_backend_sched_t sche } } -//#define DEBUG_PASS1 -//#define DEBUG_PASS2 -//#define DEBUG_PASS3 -//#define DEBUG_PASS4 +static bool lm_ggml_backend_sched_buffer_supported(lm_ggml_backend_sched_t sched, struct lm_ggml_tensor * t, int backend_id) { + lm_ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer; + lm_ggml_backend_buffer_type_t buft = NULL; + + if (buf) { + // the tensor is already allocated + buft = buf->buft; + } else { + // see if the tensor already has a backend assigned, and use the buffer type of that backend + int tensor_backend_id = tensor_backend_id(t); + if (tensor_backend_id == -1 && t->view_src) { + tensor_backend_id = tensor_backend_id(t->view_src); + } + if (tensor_backend_id != -1) { + buft = sched->bufts[tensor_backend_id]; + } + } + + return buft != NULL && lm_ggml_backend_supports_buft(sched->backends[backend_id], buft); +} + +static void lm_ggml_backend_sched_set_if_supported(lm_ggml_backend_sched_t sched, struct lm_ggml_tensor * node, int cur_backend_id, int * node_backend_id) { + if (lm_ggml_backend_supports_op(sched->backends[cur_backend_id], node)) { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.sup"); + } +} // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * graph) { @@ -1280,17 +1319,13 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str } } } -#ifdef DEBUG_PASS1 - fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); lm_ggml_backend_sched_print_assignments(sched, graph); -#endif // pass 2: expand current backend assignments // assign the same backend to adjacent nodes // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend) // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops - - - // pass 2.2 expand gpu down + // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of its inputs are known + // expand gpu down { int cur_backend_id = -1; for (int i = 0; i < graph->n_nodes; i++) { @@ -1306,13 +1341,12 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str } else { cur_backend_id = *node_backend_id; } - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.2"); + } else if (cur_backend_id != -1) { + lm_ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } - // pass 2.1 expand gpu up + // expand gpu up { int cur_backend_id = -1; for (int i = graph->n_nodes - 1; i >= 0; i--) { @@ -1328,13 +1362,12 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str } else { cur_backend_id = *node_backend_id; } - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.1"); + } else if (cur_backend_id != -1) { + lm_ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } - // pass 2.4 expand rest down + // expand rest down { int cur_backend_id = -1; for (int i = 0; i < graph->n_nodes; i++) { @@ -1345,13 +1378,12 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str int * node_backend_id = &tensor_backend_id(node); if (*node_backend_id != -1) { cur_backend_id = *node_backend_id; - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.4"); + } else if (cur_backend_id != -1) { + lm_ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } - // pass 2.3 expand rest up + // expand rest up { int cur_backend_id = -1; for (int i = graph->n_nodes - 1; i >= 0; i--) { @@ -1362,24 +1394,80 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str int * node_backend_id = &tensor_backend_id(node); if (*node_backend_id != -1) { cur_backend_id = *node_backend_id; - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.3"); + } else if (cur_backend_id != -1) { + lm_ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } -#ifdef DEBUG_PASS2 - fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); lm_ggml_backend_sched_print_assignments(sched, graph); -#endif + // pass 3: upgrade nodes to higher prio backends with compatible buffer types + // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there + // however, we also need to verify that the sources are in compatible buffer types + // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph + // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same + // this is not uncommon since multiple backends can use host memory, with the same buffer type (eg. BLAS and CPU) + // additionally, set remaining unassigned nodes to the backend with the most supported inputs + // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point + for (int i = 0; i < graph->n_nodes; i++) { + struct lm_ggml_tensor * node = graph->nodes[i]; + if (lm_ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id == -1) { + // unassigned node: find the backend with the most supported inputs + int n_supported_best = -1; + for (int b = 0; b < sched->n_backends; b++) { + if (lm_ggml_backend_supports_op(sched->backends[b], node)) { + int n_supported = 0; + for (int j = 0; j < LM_GGML_MAX_SRC; j++) { + struct lm_ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && lm_ggml_backend_sched_buffer_supported(sched, src, b)) { + n_supported++; + } + } + if (n_supported > n_supported_best) { + n_supported_best = n_supported; + *node_backend_id = b; + SET_CAUSE(node, "3.best"); + } + } + } + } else { + // assigned node: upgrade to higher prio backend if possible + for (int b = 0; b < *node_backend_id; b++) { + if (sched->bufts[b] == sched->bufts[*node_backend_id] && lm_ggml_backend_supports_op(sched->backends[b], node)) { + bool supported = true; + for (int j = 0; j < LM_GGML_MAX_SRC; j++) { + struct lm_ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if (!lm_ggml_backend_sched_buffer_supported(sched, src, b)) { + supported = false; + break; + } + } + if (supported) { + *node_backend_id = b; + SET_CAUSE(node, "3.upg"); + break; + } + } + } + } + } - // pass 3: assign backends to remaining src from dst and view_src + // pass 4: assign backends to remaining src from dst and view_src for (int i = 0; i < graph->n_nodes; i++) { struct lm_ggml_tensor * node = graph->nodes[i]; int * cur_backend_id = &tensor_backend_id(node); if (node->view_src != NULL && *cur_backend_id == -1) { *cur_backend_id = tensor_backend_id(node->view_src); - SET_CAUSE(node, "3.vsrc"); + SET_CAUSE(node, "4.vsrc"); } for (int j = 0; j < LM_GGML_MAX_SRC; j++) { struct lm_ggml_tensor * src = node->src[j]; @@ -1391,17 +1479,14 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str if (src->view_src != NULL) { // views are always on the same backend as the source *src_backend_id = tensor_backend_id(src->view_src); - SET_CAUSE(src, "3.vsrc"); + SET_CAUSE(src, "4.vsrc"); } else { *src_backend_id = *cur_backend_id; - SET_CAUSE(src, "3.cur"); + SET_CAUSE(src, "4.cur"); } } } } -#ifdef DEBUG_PASS3 - fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); lm_ggml_backend_sched_print_assignments(sched, graph); -#endif // pass 4: split graph, find tensors that need to be copied { @@ -1448,10 +1533,12 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str } } // check if the split has too many inputs + // FIXME: count the number of inputs instead of only checking when full if (split->n_inputs == LM_GGML_SCHED_MAX_SPLIT_INPUTS) { const size_t id = hash_id(src); int src_backend_id = sched->tensor_backend_id[id]; - if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) { + bool supported = lm_ggml_backend_sched_buffer_supported(sched, src, cur_backend_id); + if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL && !supported) { //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name); need_new_split = true; break; @@ -1486,7 +1573,7 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str const int src_backend_id = tensor_backend_id(src); assert(src_backend_id != -1); // all inputs should be assigned by now - if (src->flags & LM_GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { + if (src->flags & LM_GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { size_t id = hash_id(src); if (sched->tensor_copies[id][src_backend_id][0] == NULL) { lm_ggml_backend_t backend = sched->backends[src_backend_id]; @@ -1511,7 +1598,8 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str } } - if (src_backend_id != node_backend_id) { + bool supported = lm_ggml_backend_sched_buffer_supported(sched, src, cur_backend_id); + if (src_backend_id != cur_backend_id && !supported) { // create a copy of the input in the split's backend const size_t id = hash_id(src); if (sched->tensor_copies[id][cur_backend_id][0] == NULL) { @@ -1537,9 +1625,21 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str split->i_end = graph->n_nodes; sched->n_splits = i_split + 1; } -#ifdef DEBUG_PASS4 - fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); lm_ggml_backend_sched_print_assignments(sched, graph); -#endif + + if (sched->debug) { + lm_ggml_backend_sched_print_assignments(sched, graph); + } + + // swap node_backend_ids and leaf_backend_ids and prevs + { + int * tmp = sched->node_backend_ids; + sched->node_backend_ids = sched->prev_node_backend_ids; + sched->prev_node_backend_ids = tmp; + + tmp = sched->leaf_backend_ids; + sched->leaf_backend_ids = sched->prev_leaf_backend_ids; + sched->prev_leaf_backend_ids = tmp; + } // create copies of the graph for each split // TODO: avoid this copy @@ -1613,8 +1713,26 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str } static bool lm_ggml_backend_sched_alloc_splits(lm_ggml_backend_sched_t sched) { + bool backend_ids_changed = false; + for (int i = 0; i < sched->graph->n_nodes; i++) { + if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] && + sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) { + backend_ids_changed = true; + break; + } + } + if (!backend_ids_changed) { + for (int i = 0; i < sched->graph->n_leafs; i++) { + if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] && + sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) { + backend_ids_changed = true; + break; + } + } + } + // allocate graph - if (!lm_ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + if (backend_ids_changed || !lm_ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { // the re-allocation may cause the split inputs to be moved to a different address lm_ggml_backend_sched_synchronize(sched); #ifndef NDEBUG @@ -1727,6 +1845,8 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new( struct lm_ggml_backend_sched * sched = calloc(1, sizeof(struct lm_ggml_backend_sched)); + sched->debug = getenv("LM_GGML_SCHED_DEBUG") != NULL; + // initialize hash table sched->hash_set = lm_ggml_hash_set_new(graph_size); sched->tensor_backend_id = calloc(sched->hash_set.size, sizeof(sched->tensor_backend_id[0])); @@ -1735,6 +1855,8 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new( const size_t nodes_size = graph_size + LM_GGML_SCHED_MAX_SPLITS*LM_GGML_SCHED_MAX_SPLIT_INPUTS*2; sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); + sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); sched->n_backends = n_backends; @@ -1747,7 +1869,7 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new( for (int b = 0; b < n_backends; b++) { sched->backends[b] = backends[b]; sched->bufts[b] = bufts ? bufts[b] : lm_ggml_backend_get_default_buffer_type(backends[b]); - LM_GGML_ASSERT(lm_ggml_backend_buft_supports_backend(sched->bufts[b], backends[b])); + LM_GGML_ASSERT(lm_ggml_backend_supports_buft(backends[b], sched->bufts[b])); if (sched->n_copies > 1) { for (int c = 0; c < sched->n_copies; c++) { sched->events[b][c] = lm_ggml_backend_event_new(backends[b]); @@ -1779,6 +1901,8 @@ void lm_ggml_backend_sched_free(lm_ggml_backend_sched_t sched) { free(sched->tensor_copies); free(sched->node_backend_ids); free(sched->leaf_backend_ids); + free(sched->prev_node_backend_ids); + free(sched->prev_leaf_backend_ids); free(sched); } @@ -1864,6 +1988,15 @@ int lm_ggml_backend_sched_get_n_copies(lm_ggml_backend_sched_t sched) { return sched->n_copies; } +int lm_ggml_backend_sched_get_n_backends(lm_ggml_backend_sched_t sched) { + return sched->n_backends; +} + +lm_ggml_backend_t lm_ggml_backend_sched_get_backend(lm_ggml_backend_sched_t sched, int i) { + LM_GGML_ASSERT(i >= 0 && i < sched->n_backends); + return sched->backends[i]; +} + size_t lm_ggml_backend_sched_get_buffer_size(lm_ggml_backend_sched_t sched, lm_ggml_backend_t backend) { int backend_index = lm_ggml_backend_sched_backend_id(sched, backend); LM_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); @@ -1875,6 +2008,7 @@ void lm_ggml_backend_sched_set_tensor_backend(lm_ggml_backend_sched_t sched, str int backend_index = lm_ggml_backend_sched_backend_id(sched, backend); LM_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); tensor_backend_id(node) = backend_index; + SET_CAUSE(node, "usr"); } lm_ggml_backend_t lm_ggml_backend_sched_get_tensor_backend(lm_ggml_backend_sched_t sched, struct lm_ggml_tensor * node) { @@ -1887,15 +2021,15 @@ lm_ggml_backend_t lm_ggml_backend_sched_get_tensor_backend(lm_ggml_backend_sched // utils -void lm_ggml_backend_view_init(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) { +void lm_ggml_backend_view_init(struct lm_ggml_tensor * tensor) { LM_GGML_ASSERT(tensor->buffer == NULL); LM_GGML_ASSERT(tensor->view_src != NULL); LM_GGML_ASSERT(tensor->view_src->buffer != NULL); LM_GGML_ASSERT(tensor->view_src->data != NULL); - tensor->buffer = buffer; + tensor->buffer = tensor->view_src->buffer; tensor->data = (char *)tensor->view_src->data + tensor->view_offs; - lm_ggml_backend_buffer_init_tensor(buffer, tensor); + lm_ggml_backend_buffer_init_tensor(tensor->buffer, tensor); } void lm_ggml_backend_tensor_alloc(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, void * addr) { @@ -1954,7 +2088,7 @@ static void graph_copy_init_tensor(struct lm_ggml_hash_set hash_set, struct lm_g struct lm_ggml_tensor * dst = node_copies[id]; if (dst->view_src != NULL) { graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src); - lm_ggml_backend_view_init(dst->view_src->buffer, dst); + lm_ggml_backend_view_init(dst); } else { lm_ggml_backend_tensor_copy(src, dst); diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h index f1b01127..e0177c2c 100644 --- a/cpp/ggml-backend.h +++ b/cpp/ggml-backend.h @@ -23,28 +23,29 @@ extern "C" { LM_GGML_API size_t lm_ggml_backend_buft_get_alignment (lm_ggml_backend_buffer_type_t buft); LM_GGML_API size_t lm_ggml_backend_buft_get_max_size (lm_ggml_backend_buffer_type_t buft); LM_GGML_API LM_GGML_CALL size_t lm_ggml_backend_buft_get_alloc_size (lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor); - LM_GGML_API bool lm_ggml_backend_buft_supports_backend(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_t backend); LM_GGML_API bool lm_ggml_backend_buft_is_host (lm_ggml_backend_buffer_type_t buft); // buffer enum lm_ggml_backend_buffer_usage { LM_GGML_BACKEND_BUFFER_USAGE_ANY = 0, LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, + LM_GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2, }; - LM_GGML_API const char * lm_ggml_backend_buffer_name (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void lm_ggml_backend_buffer_free (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void * lm_ggml_backend_buffer_get_base (lm_ggml_backend_buffer_t buffer); - LM_GGML_API size_t lm_ggml_backend_buffer_get_size (lm_ggml_backend_buffer_t buffer); - LM_GGML_API LM_GGML_CALL void lm_ggml_backend_buffer_init_tensor (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); - LM_GGML_API size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer); - LM_GGML_API size_t lm_ggml_backend_buffer_get_max_size (lm_ggml_backend_buffer_t buffer); - LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_backend_buffer_clear (lm_ggml_backend_buffer_t buffer, uint8_t value); - LM_GGML_API bool lm_ggml_backend_buffer_is_host (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void lm_ggml_backend_buffer_set_usage (lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); - LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_buffer_get_type (lm_ggml_backend_buffer_t buffer); - LM_GGML_API void lm_ggml_backend_buffer_reset (lm_ggml_backend_buffer_t buffer); + LM_GGML_API const char * lm_ggml_backend_buffer_name (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_free (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void * lm_ggml_backend_buffer_get_base (lm_ggml_backend_buffer_t buffer); + LM_GGML_API size_t lm_ggml_backend_buffer_get_size (lm_ggml_backend_buffer_t buffer); + LM_GGML_API LM_GGML_CALL void lm_ggml_backend_buffer_init_tensor (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); + LM_GGML_API size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer); + LM_GGML_API size_t lm_ggml_backend_buffer_get_max_size (lm_ggml_backend_buffer_t buffer); + LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_backend_buffer_clear (lm_ggml_backend_buffer_t buffer, uint8_t value); + LM_GGML_API bool lm_ggml_backend_buffer_is_host (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_set_usage (lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); + LM_GGML_API enum lm_ggml_backend_buffer_usage lm_ggml_backend_buffer_get_usage (lm_ggml_backend_buffer_t buffer); + LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_buffer_get_type (lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_buffer_reset (lm_ggml_backend_buffer_t buffer); // // Backend @@ -74,6 +75,7 @@ extern "C" { LM_GGML_API enum lm_ggml_status lm_ggml_backend_graph_compute (lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); LM_GGML_API enum lm_ggml_status lm_ggml_backend_graph_compute_async(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph); LM_GGML_API bool lm_ggml_backend_supports_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); + LM_GGML_API bool lm_ggml_backend_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft); LM_GGML_API bool lm_ggml_backend_offload_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op); // tensor copy between different backends @@ -90,7 +92,7 @@ extern "C" { LM_GGML_API void lm_ggml_backend_event_free (lm_ggml_backend_event_t event); LM_GGML_API void lm_ggml_backend_event_record (lm_ggml_backend_event_t event); LM_GGML_API void lm_ggml_backend_event_synchronize(lm_ggml_backend_event_t event); - LM_GGML_API void lm_ggml_backend_event_wait (lm_ggml_backend_t backend, lm_ggml_backend_event_t event); // wait async on event + LM_GGML_API void lm_ggml_backend_event_wait (lm_ggml_backend_t backend, lm_ggml_backend_event_t event); // // CPU backend @@ -119,7 +121,7 @@ extern "C" { LM_GGML_API size_t lm_ggml_backend_reg_get_count(void); LM_GGML_API size_t lm_ggml_backend_reg_find_by_name(const char * name); - LM_GGML_API lm_ggml_backend_t lm_ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params] + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional) LM_GGML_API const char * lm_ggml_backend_reg_get_name(size_t i); LM_GGML_API lm_ggml_backend_t lm_ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_reg_get_default_buffer_type(size_t i); @@ -182,6 +184,9 @@ extern "C" { // Initialize backend buffers from a measure graph LM_GGML_API bool lm_ggml_backend_sched_reserve(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * measure_graph); + LM_GGML_API int lm_ggml_backend_sched_get_n_backends(lm_ggml_backend_sched_t sched); + LM_GGML_API lm_ggml_backend_t lm_ggml_backend_sched_get_backend(lm_ggml_backend_sched_t sched, int i); + // Get the number of splits of the last graph LM_GGML_API int lm_ggml_backend_sched_get_n_splits(lm_ggml_backend_sched_t sched); LM_GGML_API int lm_ggml_backend_sched_get_n_copies(lm_ggml_backend_sched_t sched); @@ -225,7 +230,7 @@ extern "C" { // Tensor initialization LM_GGML_API void lm_ggml_backend_tensor_alloc(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, void * addr); - LM_GGML_API void lm_ggml_backend_view_init(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_backend_view_init(struct lm_ggml_tensor * tensor); #ifdef __cplusplus diff --git a/cpp/ggml-common.h b/cpp/ggml-common.h index d70faf82..87065fc5 100644 --- a/cpp/ggml-common.h +++ b/cpp/ggml-common.h @@ -106,28 +106,34 @@ typedef sycl::half2 lm_ggml_half2; #define QR6_K 2 #define QI2_XXS (QK_K / (4*QR2_XXS)) -#define QR2_XXS 8 +#define QR2_XXS 4 #define QI2_XS (QK_K / (4*QR2_XS)) -#define QR2_XS 8 +#define QR2_XS 4 #define QI2_S (QK_K / (4*QR2_S)) -#define QR2_S 8 +#define QR2_S 4 #define QI3_XXS (QK_K / (4*QR3_XXS)) -#define QR3_XXS 8 +#define QR3_XXS 4 #define QI3_XS (QK_K / (4*QR3_XS)) -#define QR3_XS 8 +#define QR3_XS 4 #define QI1_S (QK_K / (4*QR1_S)) #define QR1_S 8 +#define QI1_M (QK_K / (4*QR1_M)) +#define QR1_M 8 + #define QI4_NL (QK4_NL / (4*QR4_NL)) #define QR4_NL 2 #define QI4_XS (QK_K / (4*QR4_XS)) -#define QR4_XS 8 +#define QR4_XS 2 + +#define QI3_S (QK_K / (4*QR3_S)) +#define QR3_S 4 #endif // LM_GGML_COMMON_DECL_CUDA || LM_GGML_COMMON_DECL_HIP @@ -193,6 +199,30 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(lm_ggml_half) + QK8_1, "wrong q8_1 block size/padding"); +typedef struct { + lm_ggml_half d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(lm_ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +typedef struct { + lm_ggml_half d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(lm_ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +typedef struct { + lm_ggml_half d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +typedef struct { + lm_ggml_half d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + // // Super-block quantization structures // diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index 0b0b0aa6..3d770449 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -17,7 +17,7 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -#if defined(_WIN32) +#if defined(_MSC_VER) #define m512bh(p) p #define m512i(p) p @@ -609,6 +609,10 @@ static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { #endif // defined(__ARM_NEON) && (!defined(__MSC_VER) +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE + // precomputed f32 table for f16 (256 KB) // defined in ggml.c, initialized in lm_ggml_init() extern float lm_ggml_table_f32_f16[1 << 16]; diff --git a/cpp/ggml-metal.h b/cpp/ggml-metal.h index 914dd6cf..b8d09095 100644 --- a/cpp/ggml-metal.h +++ b/cpp/ggml-metal.h @@ -1,7 +1,7 @@ // An interface allowing to compute lm_ggml_cgraph with Metal // // This is a fully functional interface that extends ggml with GPU support for Apple devices. -// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.) // // How it works? // @@ -63,4 +63,3 @@ LM_GGML_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t ba #ifdef __cplusplus } #endif - diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index b6423cb4..8eb8263b 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -172,8 +172,10 @@ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, - LM_GGML_METAL_KERNEL_TYPE_ROPE_F32, - LM_GGML_METAL_KERNEL_TYPE_ROPE_F16, + LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -191,16 +193,16 @@ //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 - LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, - LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, LM_GGML_METAL_KERNEL_TYPE_CONCAT, LM_GGML_METAL_KERNEL_TYPE_SQR, LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, @@ -626,8 +628,10 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -647,14 +651,14 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); @@ -731,6 +735,12 @@ static void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { } static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, const struct lm_ggml_tensor * op) { + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) { + return false; + } + } + switch (op->op) { case LM_GGML_OP_UNARY: switch (lm_ggml_get_unary_op(op)) { @@ -740,7 +750,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, case LM_GGML_UNARY_OP_GELU: case LM_GGML_UNARY_OP_GELU_QUICK: case LM_GGML_UNARY_OP_SILU: - return true; + return lm_ggml_is_contiguous(op->src[0]); default: return false; } @@ -779,6 +789,12 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, case LM_GGML_OP_LEAKY_RELU: return true; case LM_GGML_OP_FLASH_ATTN_EXT: + if (op->src[1]->type != LM_GGML_TYPE_F16) { + return false; + } + if (op->src[2]->type != LM_GGML_TYPE_F16) { + return false; + } if (op->src[0]->ne[0] == 256) { return false; } @@ -794,8 +810,8 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, switch (op->src[0]->type) { case LM_GGML_TYPE_F32: switch (op->type) { - case LM_GGML_TYPE_F16: case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_F16: case LM_GGML_TYPE_Q8_0: case LM_GGML_TYPE_Q4_0: case LM_GGML_TYPE_Q4_1: @@ -808,8 +824,8 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, } case LM_GGML_TYPE_F16: switch (op->type) { - case LM_GGML_TYPE_F16: case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_F16: return true; default: return false; @@ -821,7 +837,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, case LM_GGML_OP_DIAG_MASK_INF: case LM_GGML_OP_GET_ROWS: { - return op->src[0]->type != LM_GGML_TYPE_BF16 && op->ne[3] == 1; + return op->ne[3] == 1; } default: return false; @@ -1519,7 +1535,6 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( { LM_GGML_ASSERT(ne00 == ne10); - // TODO: assert that dim2 and dim3 are contiguous LM_GGML_ASSERT(ne12 % ne02 == 0); LM_GGML_ASSERT(ne13 % ne03 == 0); @@ -1565,8 +1580,8 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) switch (src0->type) { - case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; - case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; + case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; default: break; } @@ -1771,10 +1786,6 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( } }; - if (lm_ggml_is_quantized(src0t)) { - LM_GGML_ASSERT(ne00 >= nth0*nth1); - } - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1853,9 +1864,10 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( // ne21 = n_rows const int dst_rows = ne20*ne21; const int dst_rows_min = n_as; + const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; // max size of the rowids array in the kernel shared buffer - LM_GGML_ASSERT(dst_rows <= 2048); + LM_GGML_ASSERT(dst_rows <= dst_rows_max); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -2187,6 +2199,7 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( case LM_GGML_OP_RMS_NORM: { LM_GGML_ASSERT(ne00 % 4 == 0); + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -2214,6 +2227,7 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( case LM_GGML_OP_GROUP_NORM: { LM_GGML_ASSERT(ne00 % 4 == 0); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); //float eps; //memcpy(&eps, dst->op_params, sizeof(float)); @@ -2247,6 +2261,8 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( } break; case LM_GGML_OP_NORM: { + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); + float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -2276,7 +2292,7 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; float freq_base; float freq_scale; @@ -2293,22 +2309,23 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - LM_GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); + id pipeline = nil; if (!is_neox) { - LM_GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: LM_GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: LM_GGML_ASSERT(false); + }; } - id pipeline = nil; - - switch (src0->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; - default: LM_GGML_ASSERT(false); - }; - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2336,14 +2353,13 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&mode length:sizeof( int) atIndex:22]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2755,8 +2771,8 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( LM_GGML_ASSERT(ne0 % lm_ggml_blck_size(dst->type) == 0); switch (dstt) { - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; @@ -2769,8 +2785,8 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute( case LM_GGML_TYPE_F16: { switch (dstt) { - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; default: LM_GGML_ASSERT(false && "not implemented"); }; } break; @@ -3031,12 +3047,6 @@ LM_GGML_CALL static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggm UNUSED(buft); } -LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_type_supports_backend(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_t backend) { - return lm_ggml_backend_is_metal(backend) || lm_ggml_backend_is_cpu(backend); - - UNUSED(buft); -} - LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { return true; @@ -3051,7 +3061,6 @@ LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(voi /* .get_alignment = */ lm_ggml_backend_metal_buffer_type_get_alignment, /* .get_max_size = */ lm_ggml_backend_metal_buffer_type_get_max_size, /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .supports_backend = */ lm_ggml_backend_metal_buffer_type_supports_backend, /* .is_host = */ lm_ggml_backend_metal_buffer_type_is_host, }, /* .context = */ NULL, @@ -3166,6 +3175,12 @@ LM_GGML_CALL static bool lm_ggml_backend_metal_supports_op(lm_ggml_backend_t bac return lm_ggml_metal_supports_op(metal_ctx, op); } +LM_GGML_CALL static bool lm_ggml_backend_metal_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name; + + UNUSED(backend); +} + static struct lm_ggml_backend_i lm_ggml_backend_metal_i = { /* .get_name = */ lm_ggml_backend_metal_name, /* .free = */ lm_ggml_backend_metal_free, @@ -3176,9 +3191,11 @@ LM_GGML_CALL static bool lm_ggml_backend_metal_supports_op(lm_ggml_backend_t bac /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ lm_ggml_backend_metal_graph_compute, /* .supports_op = */ lm_ggml_backend_metal_supports_op, + /* .supports_buft = */ lm_ggml_backend_metal_supports_buft, /* .offload_op = */ NULL, /* .event_new = */ NULL, /* .event_free = */ NULL, diff --git a/cpp/ggml-metal.metal b/cpp/ggml-metal.metal new file mode 100644 index 00000000..b16f2b7e --- /dev/null +++ b/cpp/ggml-metal.metal @@ -0,0 +1,6520 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#include "ggml-common.h" + +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +// general-purpose kernel for addition, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast across all dims +// cons: not very efficient +kernel void kernel_add( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_mul( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_div( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + } +} + +template +kernel void kernel_repeat( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 % ne03; + const int64_t i02 = i2 % ne02; + const int64_t i01 = i1 % ne01; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i00 = i0 % ne00; + *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_mul_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +template +kernel void kernel_soft_max( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +template +kernel void kernel_soft_max_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + // MEAN + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float mean = sum[0] / ne00; + + // recenter and VARIANCE + threadgroup_barrier(mem_flags::mem_threadgroup); + device float * y = dst + tgpig*ne00; + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; + sum[tpitg] += y[i00] * y[i00]; + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float variance = sum[0] / ne00; + + const float scale = 1.0f/sqrt(variance + eps); + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = all_sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + all_sum = buf[tiisg]; + all_sum = simd_sum(all_sum); + } + + const float mean = all_sum/ne00; + const float scale = 1.0f/sqrt(mean + eps); + + device float4 * y = (device float4 *) (dst + tgpig*ne00); + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int32_t & n_groups, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = ne00*ne01*ne02; + const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// guard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template +void mul_vec_q_n_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, uint tiisg, uint sgitg) { + const int nb = ne00/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; + + device const float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + } + + yb += QK4_0 * 16; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + + +#define NB_Q8_0 8 + +void kernel_mul_mv_q8_0_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = ne00/QK8_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[NB_Q8_0]; + float sumf[nr]={0.f}; + + const int ix = tiisg/4; + const int il = tiisg%4; + + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + float sumq = 0.f; + for (int iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*x[ib+row*nb].d; + } + + yb += NB_Q8_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +#define N_F32_F32 4 + +void kernel_mul_mv_f32_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F32_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const float * x = (device const float *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const float4 * x4 = (device const float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_f32_f32")]] +kernel void kernel_mul_mv_f32_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F16 4 + +kernel void kernel_mul_mv_f16_f16( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F16; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + device const half4 * y4 = (device const half4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +void kernel_mul_mv_f16_f32_1row_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_f16_f32_1row")]] +kernel void kernel_mul_mv_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F32 4 + +void kernel_mul_mv_f16_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_f16_f32")]] +kernel void kernel_mul_mv_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mv_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half4 * x4 = (device const half4 *) (src0 + offset0); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template +kernel void kernel_rope( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + const bool is_neox = mode & 2; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; + + const float theta_0 = (float)p; + const float inv_ndims = -1.f/n_dims; + + if (!is_neox) { + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + + const float theta = theta_0 * pow(freq_base, inv_ndims*i0); + float cos_theta, sin_theta; + rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const T x0 = src[0]; + const T x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } else { + for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { + if (ic < n_dims) { + const int64_t ib = 0; + + // simplified from `(ib * n_dims + ic) * inv_ndims` + const float cur_rot = inv_ndims*ic - ib; + const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f; + + const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; + + float cos_theta, sin_theta; + rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const int64_t i0 = ib*n_dims + ic/2; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } + } +} + +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + +typedef void (im2col_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; + + const int32_t offset_dst = + (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & sf0, + constant float & sf1, + constant float & sf2, + constant float & sf3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3/sf3; + const int64_t i02 = i2/sf2; + const int64_t i01 = i1/sf1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int64_t i00 = i0/sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_ptr[0] = src0_ptr[0]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols_pad) return; + + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant float & slope, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; +} + +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + //const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + + const short tx = tiisg%4; + const short ty = tiisg/4; + + if (mask != q) { + // mqk = mqk*scale + mask*slope + ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; + } else { + // mqk = mqk*scale + ss[8*cc + ty*TF + 2*tx + 0] *= scale; + ss[8*cc + ty*TF + 2*tx + 1] *= scale; + } + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const float m = M[j]; + const float s = ss[j*TF + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tiisg == 0) { + mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + +kernel void kernel_cpy_f16_f16( + device const half * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + device const half * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + device const float * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_q8_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + + device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + +kernel void kernel_cpy_f32_q4_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q4_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q5_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_cpy_f32_iq4_nl( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + + } +} + +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & dim, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + device const float * x; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} + +void kernel_mul_mv_q2_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q2_K) * nb; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int ip = tid/4; // 0 or 1 + const int il = 2*((tid%4)/2); // 0 or 2 + const int ir = tid%2; + const int n = 8; + const int l0 = n*ir; + + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; + + const int shift = 2*il; + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + + const int step = sizeof(block_q3_K) * nb / 2; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[2] = {0.f}; + float sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 4) { + + for (int l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += step; + h += step; + a += step; + dh += step; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + if (tiisg == 0) { + for (int row = 0; row < 2; ++row) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + } + } +} + +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q5_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf[2]={0.f}; + + const int step = sizeof(block_q5_K) * nb; + + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int iq = tid/4; + const int ir = tid%4; + const int n = 8; + + const int l0 = n*ir; + const int q_offset = 32*iq + l0; + const int y_offset = 64*iq + l0; + + const uint8_t hm1 = 1u << (2*iq); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q6_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int row = 2 * r0 + sgitg; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; + const int n = 4; + const int l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += 2) { + + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = x[i].qh + q_offset_h; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + y_offset; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} + +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +// ======================= "True" 2-bit + +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_xxs)/2; + q3 += nb*sizeof(block_iq3_xxs); + gas += nb*sizeof(block_iq3_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + } + } +} + +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_s)/2; + qs += nb*sizeof(block_iq3_s); + qh += nb*sizeof(block_iq3_s); + sc += nb*sizeof(block_iq3_s); + signs += nb*sizeof(block_iq3_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += nb*sizeof(block_iq2_s)/2; + qs += nb*sizeof(block_iq2_s); + qh += nb*sizeof(block_iq2_s); + sc += nb*sizeof(block_iq2_s); + signs += nb*sizeof(block_iq2_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq1_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float sumy = 0; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += nb*sizeof(block_iq1_s)/2; + qs += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq1_m_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + iq1m_scale_t scale; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); + + sc += nb*sizeof(block_iq1_m)/2; + qs += nb*sizeof(block_iq1_m); + qh += nb*sizeof(block_iq1_m); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +//============================= templates and their specializations ============================= + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.f; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; + for (int i = 0; i < 16; ++i) { + const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } +} + +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; + + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; + + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +kernel void kernel_get_rows( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + //const int64_t i = tgpig; + //const int64_t r = ((device int32_t *) src1)[i]; + + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + float4x4 temp; + dequantize_func( + ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +kernel void kernel_get_rows_f32( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_i32( + device const void * src0, + device const char * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +void kernel_mul_mm_impl(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg == 0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +template +void kernel_mul_mm_id_impl( + device const uchar * src0, + device const uchar * src1, + threadgroup ushort2 * rowids, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + int64_t ne1, + int64_t ne0ne1, + threadgroup uchar * shared_memory, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + + if (r1 * BLOCK_SIZE_N >= ne1) return; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + short il = (tiitg % THREAD_PER_ROW); + + ushort offset1 = il/nl; + + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * id[1] + + nb11 * (id[0] % ne11) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + { + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0); + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +template +kernel void kernel_mul_mm(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mm_impl( + src0, + src1, + dst, + ne00, + ne02, + nb01, + nb02, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + shared_memory, + tgpig, + tiitg, + sgitg); +} + +template +kernel void kernel_mul_mm_id( + device const uchar * src0s, + device const uchar * src1, + device float * dst, + device const uchar * ids, + constant int64_t & nei0, + constant int64_t & nei1, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int32_t i02 = tgpig.z; + tgpig.z = 0; + + device const uchar * src0 = src0s + i02*nb02; + + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + + // TODO: parallelize this loop + int64_t _ne1 = 0; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + kernel_mul_mm_id_impl( + src0, + src1, + rowids, + dst, + ne00, + ne02, + nb01, + nb02, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + _ne1, + ne0*ne1, + shared_memory, + tgpig, + tiitg, + sgitg); +} + +#define QK_NL 16 + +// +// get rows +// + +typedef void (get_rows_t)( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3, uint, uint3); + +//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; + +// +// matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm) mat_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; + +// +// indirect matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; + +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + +// +// matrix-vector multiplication +// + +typedef void (kernel_mul_mv_impl_t)( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); + +typedef void (kernel_mul_mv2_impl_t)( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); + +template +void mmv_fn( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); +} + +template +void mmv_fn( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); +} + +typedef decltype(mmv_fn) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + device const char * src0s, + device const char * src1, + device float * dst, + device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; + + tgpig.z = 0; + + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + impl_fn( + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 */ 1, + shared_values, + tgpig, + tiitg, + tiisg, + sgitg); +} + +typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index 9b411ddf..7dd0f968 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -4,8 +4,6 @@ #include "ggml-quants.h" #include "ggml-impl.h" -#define LM_GGML_COMMON_IMPL_C -#include "ggml-common.h" #include #include @@ -660,7 +658,7 @@ static inline __m128i packNibbles( __m256i bytes ) { #endif //__loongarch_asx // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { +void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -698,11 +696,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_reference(x, y, k); + quantize_row_q4_0_ref(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { +void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -740,10 +738,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_1_reference(x, y, k); + quantize_row_q4_1_ref(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { +void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -788,10 +786,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_0_reference(x, y, k); + quantize_row_q5_0_ref(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { +void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -836,11 +834,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_1_reference(x, y, k); + quantize_row_q5_1_ref(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { +void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1078,6 +1076,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) } vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } #elif defined(__loongarch_asx) for (int i = 0; i < nb; i++) { @@ -1145,12 +1144,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #else LM_GGML_UNUSED(nb); // scalar - quantize_row_q8_0_reference(x, y, k); + quantize_row_q8_0_ref(x, y, k); #endif } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { +void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -1437,6 +1436,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) accv = vec_add(accv, vec_sld(accv, accv, 4)); accv = vec_add(accv, vec_sld(accv, accv, 8)); y[i].s = LM_GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + } #elif defined(__loongarch_asx) for (int i = 0; i < nb; i++) { @@ -1508,7 +1508,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #else LM_GGML_UNUSED(nb); // scalar - quantize_row_q8_1_reference(x, y, k); + quantize_row_q8_1_ref(x, y, k); #endif } @@ -1899,7 +1899,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { +void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2002,7 +2002,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q2_K_reference(x, vy, k); + quantize_row_q2_K_ref(x, vy, k); } static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, @@ -2226,7 +2226,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { - quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2241,7 +2241,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr //========================= 3-bit (de)-quantization -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { +void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2368,7 +2368,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q3_K_reference(x, vy, k); + quantize_row_q3_K_ref(x, vy, k); } static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { @@ -2458,7 +2458,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { - quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2473,7 +2473,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { +void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2572,7 +2572,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; - quantize_row_q4_K_reference(x, y, k); + quantize_row_q4_K_ref(x, y, k); } static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2651,7 +2651,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { - quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2666,7 +2666,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { +void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -2783,7 +2783,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; - quantize_row_q5_K_reference(x, y, k); + quantize_row_q5_K_ref(x, y, k); } static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2882,7 +2882,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { - quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2897,7 +2897,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 6-bit (de)-quantization -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { +void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3001,7 +3001,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; - quantize_row_q6_K_reference(x, y, k); + quantize_row_q6_K_ref(x, y, k); } static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -3091,7 +3091,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { - quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -3108,7 +3108,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { - quantize_row_q4_0_reference(x, y, n_per_row); + quantize_row_q4_0_ref(x, y, n_per_row); return; } @@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * lm_ggml_row_size(LM_GGML_TYPE_Q4_0, n_per_row); } size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q4_0, n_per_row); @@ -3151,7 +3151,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri static_assert(QK4_1 == 32, "QK4_1 must be 32"); if (!quant_weights) { - quantize_row_q4_1_reference(x, y, n_per_row); + quantize_row_q4_1_ref(x, y, n_per_row); return; } @@ -3179,7 +3179,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * lm_ggml_row_size(LM_GGML_TYPE_Q4_1, n_per_row); } size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q4_1, n_per_row); @@ -3196,7 +3196,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri static_assert(QK5_0 == 32, "QK5_0 must be 32"); if (!quant_weights) { - quantize_row_q5_0_reference(x, y, n_per_row); + quantize_row_q5_0_ref(x, y, n_per_row); return; } @@ -3233,7 +3233,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * lm_ggml_row_size(LM_GGML_TYPE_Q5_0, n_per_row); } size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q5_0, n_per_row); @@ -3250,7 +3250,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri static_assert(QK5_1 == 32, "QK5_1 must be 32"); if (!quant_weights) { - quantize_row_q5_1_reference(x, y, n_per_row); + quantize_row_q5_1_ref(x, y, n_per_row); return; } @@ -3286,7 +3286,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * lm_ggml_row_size(LM_GGML_TYPE_Q5_1, n_per_row); } size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q5_1, n_per_row); @@ -3302,7 +3302,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, n_per_row); - quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } @@ -3590,7 +3590,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, //===================================== Q8_K ============================================== -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { +void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3641,7 +3641,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6 } void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q8_K_reference(x, y, k); + quantize_row_q8_K_ref(x, y, k); } //===================================== Dot ptoducts ================================= @@ -3808,59 +3808,61 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, vget_low_f32(sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); return; } #endif + + int ib = 0; + float sumf = 0; + #if defined(__ARM_FEATURE_SVE) - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + if (svcntb() == QK8_0) { + const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); + const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); } - - *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); @@ -3894,23 +3896,23 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. const __m256i off = _mm256_set1_epi8( 8 ); qx = _mm256_sub_epi8( qx, off ); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -3918,28 +3920,28 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void acc = _mm256_fmadd_ps( d, q, acc ); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); const __m128i lowMask = _mm_set1_epi8(0xF); const __m128i off = _mm_set1_epi8(8); - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs); __m128i bx_0 = _mm_and_si128(lowMask, tmp); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); bx_0 = _mm_sub_epi8(bx_0, off); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); bx_0 = _mm_sub_epi8(bx_0, off); const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); @@ -3950,7 +3952,7 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__SSSE3__) // set constants const __m128i lowMask = _mm_set1_epi8(0xF); @@ -3962,94 +3964,40 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void __m128 acc_2 = _mm_setzero_ps(); __m128 acc_3 = _mm_setzero_ps(); - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[0].d) * LM_GGML_FP16_TO_FP32(y[0].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[1].d) * LM_GGML_FP16_TO_FP32(y[1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( d_2_3, p3 ); - } - - assert(nb % 2 == 0); // TODO: handle odd nb - - // Main loop - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); + for (; ib + 1 < nb; ib += 2) { + _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d) ); + const __m128 d_0_1 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); bx_0 = _mm_sub_epi8(bx_0, off); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); bx_1 = _mm_sub_epi8(bx_1, off); const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[i + 1].d) * LM_GGML_FP16_TO_FP32(y[i + 1].d) ); + const __m128 d_2_3 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) ); - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); bx_2 = _mm_sub_epi8(bx_2, off); const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); bx_3 = _mm_sub_epi8(bx_3, off); const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); @@ -4072,18 +4020,16 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void acc_3 = _mm_add_ps(p3_d, acc_3); } - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk/2); - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); @@ -4106,30 +4052,29 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += sumi*LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d); + sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d); } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); const vector signed char v8 = vec_splats((signed char)0x8); vector float vsumf0 = vec_splats(0.0f); -#pragma GCC unroll 4 - for (int i = 0; i < nb; i++) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl(16, y[i].qs); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); vector signed char q4x0 = vec_and(qxs, lowMask); vector signed char q4x1 = vec_sr(qxs, v4); @@ -4140,9 +4085,10 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - qv0 = vec_add(qv0, qv1); + vector signed int vsumi0 = v0; - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); } @@ -4150,24 +4096,24 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros __m256 acc = (__m256)__lasx_xvldi(0); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s( LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = __lasx_xvreplfr2vr_s( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. const __m256i off = __lasx_xvreplgr2vr_b( 8 ); qx = __lasx_xvsub_b( qx, off ); - __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -4175,7 +4121,7 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void acc = __lasx_xvfmadd_s( d, q, acc ); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__loongarch_sx) // set constants const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); @@ -4187,89 +4133,38 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void __m128 acc_2 = __lsx_vldi(0); __m128 acc_3 = __lsx_vldi(0); - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[0].d) * LM_GGML_FP16_TO_FP32(y[0].d) ); - - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[0].qs, 0); - - __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[0].qs, 0); - bx_0 = __lsx_vsub_b(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[0].qs + 16), 0); - bx_1 = __lsx_vsub_b(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[1].d) * LM_GGML_FP16_TO_FP32(y[1].d) ); - - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[1].qs, 0); - - __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[1].qs, 0); - bx_2 = __lsx_vsub_b(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[1].qs + 16), 0); - bx_3 = __lsx_vsub_b(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = __lsx_vffint_s_w(i32_0); - __m128 p1 = __lsx_vffint_s_w(i32_1); - __m128 p2 = __lsx_vffint_s_w(i32_2); - __m128 p3 = __lsx_vffint_s_w(i32_3); - - // Apply the scale - acc_0 = __lsx_vfmul_s( d_0_1, p0 ); - acc_1 = __lsx_vfmul_s( d_0_1, p1 ); - acc_2 = __lsx_vfmul_s( d_2_3, p2 ); - acc_3 = __lsx_vfmul_s( d_2_3, p3 ); - } - - assert(nb % 2 == 0); // TODO: handle odd nb - - // Main loop - for (int i = 2; i < nb; i+=2) { + for (; ib + 1 < nb; ib += 2) { // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d) ); + const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[i].qs, 0); + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[i].qs, 0); + __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); bx_0 = __lsx_vsub_b(bx_0, off); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[i].qs + 16), 0); + __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); bx_1 = __lsx_vsub_b(bx_1, off); const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - //_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - //_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[i + 1].d) * LM_GGML_FP16_TO_FP32(y[i + 1].d) ); + const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) ); - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[i + 1].qs, 0); + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[i + 1].qs, 0); + __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); bx_2 = __lsx_vsub_b(bx_2, off); const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[i + 1].qs + 16), 0); + __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); bx_3 = __lsx_vsub_b(bx_3, off); const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); @@ -4292,27 +4187,22 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void acc_3 = __lsx_vfadd_s(p3_d, acc_3); } - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#endif + for (; ib < nb; ++ib) { int sumi = 0; for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; - const int v1 = (x[i].qs[j] >> 4) - 8; + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]); } - sumf += sumi*LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d); + sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d); } *s = sumf; -#endif } void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -4398,11 +4288,15 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); sumv2 = vaddq_f32(sumv2, summs0); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); return; } #endif + + int ib = 0; + float sumf = 0; + // TODO: add WASM SIMD #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); @@ -4410,13 +4304,11 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void float summs = 0; - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q4_1 * restrict x0 = &x[ib + 0]; + const block_q4_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib + 0]; + const block_q8_1 * restrict y1 = &y[ib + 1]; summs += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s) + LM_GGML_FP16_TO_FP32(x1->m) * LM_GGML_FP16_TO_FP32(y1->s); @@ -4445,7 +4337,7 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; #elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -4453,11 +4345,11 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void float summs = 0; // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = LM_GGML_FP16_TO_FP32(x[i].d); - const float d1 = LM_GGML_FP16_TO_FP32(y[i].d); + for (; ib < nb; ++ib) { + const float d0 = LM_GGML_FP16_TO_FP32(x[ib].d); + const float d1 = LM_GGML_FP16_TO_FP32(y[ib].d); - summs += LM_GGML_FP16_TO_FP32(x[i].m) * LM_GGML_FP16_TO_FP32(y[i].s); + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); const __m256 d0v = _mm256_set1_ps( d0 ); const __m256 d1v = _mm256_set1_ps( d1 ); @@ -4466,8 +4358,8 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[i].qs); - const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); const __m256 xy = mul_sum_us8_pairs_float(qx, qy); @@ -4479,18 +4371,16 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void #endif } - *s = hsum_float_8(acc) + summs; + sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk/2); - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); @@ -4509,43 +4399,40 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d))*sumi + LM_GGML_FP16_TO_FP32(x[i].m)*LM_GGML_FP16_TO_FP32(y[i].s); + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); #pragma GCC unroll 4 - for (int i = 0; i < nb; i++) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].m)); - vector float vys = {LM_GGML_FP16_TO_FP32(y[i].s), 0.0f, 0.0f, 0.0f}; + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {LM_GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; vsumf0 = vec_madd(vxmin, vys, vsumf0); - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl(16, y[i].qs); - - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); - qv0 = vec_add(qv0, qv1); + vector signed int vsumi0 = v0; - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vsumi0 = vec_msum(q8y0, q4x0, vsumi0); + vsumi0 = vec_msum(q8y1, q4x1, vsumi0); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); } @@ -4553,7 +4440,7 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros @@ -4562,11 +4449,11 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void float summs = 0; // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = LM_GGML_FP16_TO_FP32(x[i].d); - const float d1 = LM_GGML_FP16_TO_FP32(y[i].d); + for (; ib < nb; ++ib) { + const float d0 = LM_GGML_FP16_TO_FP32(x[ib].d); + const float d1 = LM_GGML_FP16_TO_FP32(y[ib].d); - summs += LM_GGML_FP16_TO_FP32(x[i].m) * LM_GGML_FP16_TO_FP32(y[i].s); + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); @@ -4575,8 +4462,8 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[i].qs); - const __m256i qy = __lasx_xvld( (const __m256i *)y[i].qs, 0); + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); const __m256 xy = mul_sum_us8_pairs_float(qx, qy); @@ -4584,33 +4471,31 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void acc = __lasx_xvfmadd_s( d0d1, xy, acc ); } - *s = hsum_float_8(acc) + summs; - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { int sumi = 0; for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]); } - sumf += (LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d))*sumi + LM_GGML_FP16_TO_FP32(x[i].m)*LM_GGML_FP16_TO_FP32(y[i].s); + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); } *s = sumf; -#endif } void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; + int ib = 0; + float sumf = 0; + assert(n % qk == 0); assert(qk == QK5_0); assert(nrc == 1); @@ -4632,13 +4517,11 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void uint64_t tmp0[4]; uint64_t tmp1[4]; - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q5_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -4690,7 +4573,7 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); @@ -4698,9 +4581,9 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void uint64_t tmp[4]; // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; + for (; ib < nb; ++ib) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; const v128_t m4b = wasm_i8x16_splat(0x0F); @@ -4750,23 +4633,23 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void wasm_f32x4_splat(LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d)))); } - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d)); + const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); qx = _mm256_or_si256(qx, bxhi); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -4774,19 +4657,19 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void acc = _mm256_fmadd_ps(d, q, acc); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); __m128i mask = _mm_set1_epi8((char)0xF0); // Main loop - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d)); + const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); - __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); bxhil = _mm_andnot_si128(bxhil, mask); @@ -4797,7 +4680,7 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void bxh = _mm_or_si128(bxh, bxhih); bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); @@ -4805,10 +4688,8 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - uint32_t qh; size_t vl = __riscv_vsetvl_e8m1(qk/2); @@ -4820,8 +4701,8 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); @@ -4840,10 +4721,10 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); @@ -4870,8 +4751,6 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void sumf += (LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d)) * sumi; } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)4); @@ -4879,27 +4758,27 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void vector float vsumf0 = vec_splats(0.0f); #pragma GCC unroll 4 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[i].qh[0]]), (uint64_t)(table_b2b_1[x[i].qh[1]])}; - vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[i].qh[2]]), (uint64_t)(table_b2b_1[x[i].qh[3]])}; + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; vector signed char qh0 = (vector signed char)aux64x2_0; vector signed char qh1 = (vector signed char)aux64x2_1; - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl( 16, y[i].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); @@ -4914,23 +4793,23 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros __m256 acc = (__m256)__lasx_xvldi(0); // Main loop - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d)); //FIXME + const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); //FIXME - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); qx = __lasx_xvor_v(qx, bxhi); - __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -4938,15 +4817,11 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void acc = __lasx_xvfmadd_s(d, q, acc); } - *s = hsum_float_8(acc); - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + memcpy(&qh, x[ib].qh, sizeof(qh)); int sumi = 0; @@ -4954,23 +4829,25 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + const int32_t x0 = ((x[ib].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[ib].qs[j] >> 4) | xh_1) - 16; - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + sumi += (x0 * y[ib].qs[j]) + (x1 * y[ib].qs[j + qk/2]); } - sumf += (LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d)) * sumi; + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi; } *s = sumf; -#endif } void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_1; const int nb = n / qk; + int ib = 0; + float sumf = 0; + assert(n % qk == 0); assert(qk == QK5_1); assert(nrc == 1); @@ -4995,13 +4872,11 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void uint64_t tmp0[4]; uint64_t tmp1[4]; - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q5_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib]; + const block_q8_1 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -5056,7 +4931,7 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); @@ -5066,9 +4941,9 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void uint64_t tmp[4]; // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; + for (; ib < nb; ++ib) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q8_1 * restrict y0 = &y[ib]; summs += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s); @@ -5120,8 +4995,8 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void wasm_f32x4_splat(LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d)))); } - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -5129,25 +5004,25 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void float summs = 0.0f; // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[i].d)); + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d)); - summs += LM_GGML_FP16_TO_FP32(x[i].m) * LM_GGML_FP16_TO_FP32(y[i].s); + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); qx = _mm256_or_si256(qx, bxhi); - const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[i].d)); - const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_us8_pairs_float(qx, qy); acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); } - *s = hsum_float_8(acc) + summs; + sumf = hsum_float_8(acc) + summs; #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -5156,13 +5031,13 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void float summs = 0.0f; // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[i].d)); + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d)); - summs += LM_GGML_FP16_TO_FP32(x[i].m) * LM_GGML_FP16_TO_FP32(y[i].s); + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); bxhil = _mm_and_si128(bxhil, mask); @@ -5173,18 +5048,16 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void bxh = _mm_or_si128(bxh, bxhih); bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[i].d)); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); } - *s = hsum_float_8(acc) + summs; + sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - uint32_t qh; size_t vl = __riscv_vsetvl_e8m1(qk/2); @@ -5193,8 +5066,8 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); // load qh vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); @@ -5216,10 +5089,10 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); @@ -5240,50 +5113,47 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d))*sumi + LM_GGML_FP16_TO_FP32(x[i].m)*LM_GGML_FP16_TO_FP32(y[i].s); + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); #pragma GCC unroll 4 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].m)); - vector float vys = {LM_GGML_FP16_TO_FP32(y[i].s), 0.f, 0.f, 0.f}; + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {LM_GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; vsumf0 = vec_madd(vxmin, vys, vsumf0); - vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[i].qh[0]]), (uint64_t)(table_b2b_0[x[i].qh[1]])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[i].qh[2]]), (uint64_t)(table_b2b_0[x[i].qh[3]])}; + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; vector signed char qh0 = (vector signed char)aux64x2_0; vector signed char qh1 = (vector signed char)aux64x2_1; - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); - - vector signed char q5x0 = vec_or(vec_and(qxs, lowMask), qh0); - vector signed char q5x1 = vec_or(vec_sr(qxs, v4), qh1); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl( 16, y[i].qs); + vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); + vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); - vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); - qv0 = vec_add(qv0, qv1); + vector signed int vsumi0 = v0; - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vsumi0 = vec_msum(q8y0, q5x0, vsumi0); + vsumi0 = vec_msum(q8y1, q5x1, vsumi0); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); } @@ -5291,7 +5161,7 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros @@ -5300,33 +5170,29 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void float summs = 0.0f; // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[i].d)); + for (; ib < nb; ++ib) { + const __m256 dx = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d)); - summs += LM_GGML_FP16_TO_FP32(x[i].m) * LM_GGML_FP16_TO_FP32(y[i].s); + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); qx = __lasx_xvor_v(qx, bxhi); - const __m256 dy = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[i].d)); - const __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + const __m256 dy = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_us8_pairs_float(qx, qy); acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); } - *s = hsum_float_8(acc) + summs; - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + memcpy(&qh, x[ib].qh, sizeof(qh)); int sumi = 0; @@ -5334,17 +5200,16 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + sumi += (x0 * y[ib].qs[j]) + (x1 * y[ib].qs[j + qk/2]); } - sumf += (LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d))*sumi + LM_GGML_FP16_TO_FP32(x[i].m)*LM_GGML_FP16_TO_FP32(y[i].s); + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); } *s = sumf; -#endif } void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -5421,42 +5286,44 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void return; } #endif + + int ib = 0; + float sumf = 0; + #if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + if (svcntb() == QK8_0) { + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); } - - *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; const int8x16_t x0_0 = vld1q_s8(x0->qs); const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); @@ -5478,17 +5345,17 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d)); - __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -5500,15 +5367,14 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void #endif } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; size_t vl = __riscv_vsetvl_e8m1(qk); - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); @@ -5517,40 +5383,38 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); - sumf += sumi*(LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d)); + sumf += sumi*(LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)); } - - *s = sumf; - #elif defined(__POWER9_VECTOR__) + const vector signed int v0 = vec_splats((int32_t)0); vector float vsumf0 = vec_splats(0.0f); -#pragma GCC unroll 4 - for (int i = 0; i < nb; i++) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector signed char q8x0 = vec_xl( 0, x[i].qs); - vector signed char q8x1 = vec_xl(16, x[i].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl(16, y[i].qs); + vector signed char q8x0 = vec_xl( 0, x[ib].qs); + vector signed char q8x1 = vec_xl(16, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); vector signed short qv0 = vec_mule(q8x0, q8y0); vector signed short qv1 = vec_mulo(q8x0, q8y0); vector signed short qv2 = vec_mule(q8x1, q8y1); vector signed short qv3 = vec_mulo(q8x1, q8y1); - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackh(qv1)); - vector signed int vsumi1 = vec_add(vec_unpackl(qv0), vec_unpackl(qv1)); - vector signed int vsumi2 = vec_add(vec_unpackh(qv2), vec_unpackh(qv3)); - vector signed int vsumi3 = vec_add(vec_unpackl(qv2), vec_unpackl(qv3)); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; - vsumi0 = vec_add(vsumi0, vsumi2); - vsumi1 = vec_add(vsumi1, vsumi3); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + vsumi0 = vec_sum4s(qv2, vsumi0); + vsumi1 = vec_sum4s(qv3, vsumi1); vsumi0 = vec_add(vsumi0, vsumi1); @@ -5560,18 +5424,18 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros __m256 acc = (__m256)__lasx_xvldi(0); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { // Compute combined scale for the block - const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[i].d) * LM_GGML_FP16_TO_FP32(y[i].d)); - __m256i qx = __lasx_xvld((const __m256i *)x[i].qs, 0); - __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -5579,24 +5443,19 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void acc = __lasx_xvfmadd_s( d, q, acc ); } - *s = hsum_float_8(acc); - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { int sumi = 0; for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; + sumi += x[ib].qs[j]*y[ib].qs[j]; } - sumf += sumi*(LM_GGML_FP16_TO_FP32(x[i].d)*LM_GGML_FP16_TO_FP32(y[i].d)); + sumf += sumi*(LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)); } *s = sumf; -#endif } void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -5938,6 +5797,7 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0x3); const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v6 = vec_splats((unsigned char)0x6); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -5975,15 +5835,17 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; for (int j = 0; j < QK_K/128; ++j) { __builtin_prefetch(q2, 0, 1); @@ -5993,14 +5855,14 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed char qxs1 = (vector signed char)vec_xl(16, q2); q2 += 32; - vector signed char q2x00 = vec_and(qxs0, lowMask); - vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask); - vector signed char q2x10 = vec_and(qxs1, lowMask); - vector signed char q2x11 = vec_and(vec_sr(qxs1, v2), lowMask); - vector signed char q2x12 = vec_and(vec_sr(qxs1, v4), lowMask); - vector signed char q2x13 = vec_and(vec_sr(qxs1, v6), lowMask); + vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); + vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); + vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); + vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); + vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); + vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); vector signed char q8y00 = vec_xl( 0, q8); vector signed char q8y10 = vec_xl( 16, q8); @@ -6012,45 +5874,36 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed char q8y13 = vec_xl(112, q8); q8 += 128; - vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00)); - vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01)); - vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02)); - vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03)); - vector signed short qv4 = vec_add(vec_mule(q2x10, q8y10), vec_mulo(q2x10, q8y10)); - vector signed short qv5 = vec_add(vec_mule(q2x11, q8y11), vec_mulo(q2x11, q8y11)); - vector signed short qv6 = vec_add(vec_mule(q2x12, q8y12), vec_mulo(q2x12, q8y12)); - vector signed short qv7 = vec_add(vec_mule(q2x13, q8y13), vec_mulo(q2x13, q8y13)); - - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - vector signed short vs4 = vec_splat(vscales_h, 4); - vector signed short vs5 = vec_splat(vscales_h, 5); - vector signed short vs6 = vec_splat(vscales_h, 6); - vector signed short vs7 = vec_splat(vscales_h, 7); + vector signed int qv0 = vec_msum(q8y00, q2x00, v0); + vector signed int qv1 = vec_msum(q8y01, q2x01, v0); + vector signed int qv2 = vec_msum(q8y02, q2x02, v0); + vector signed int qv3 = vec_msum(q8y03, q2x03, v0); + vector signed int qv4 = vec_msum(q8y10, q2x10, v0); + vector signed int qv5 = vec_msum(q8y11, q2x11, v0); + vector signed int qv6 = vec_msum(q8y12, q2x12, v0); + vector signed int qv7 = vec_msum(q8y13, q2x13, v0); + + vector signed short vscales_07 = vec_unpackh(vscales); + vector signed int vscales_03 = vec_unpackh(vscales_07); + vector signed int vscales_47 = vec_unpackl(vscales_07); + vector signed int vs0 = vec_splat(vscales_03, 0); + vector signed int vs1 = vec_splat(vscales_03, 1); + vector signed int vs2 = vec_splat(vscales_03, 2); + vector signed int vs3 = vec_splat(vscales_03, 3); + vector signed int vs4 = vec_splat(vscales_47, 0); + vector signed int vs5 = vec_splat(vscales_47, 1); + vector signed int vs6 = vec_splat(vscales_47, 2); + vector signed int vs7 = vec_splat(vscales_47, 3); vscales = vec_sld(vscales, vscales, 8); - qv0 = vec_mul(qv0, vs0); - qv1 = vec_mul(qv1, vs2); - qv2 = vec_mul(qv2, vs4); - qv3 = vec_mul(qv3, vs6); - - qv0 = vec_madd(qv4, vs1, qv0); - qv1 = vec_madd(qv5, vs3, qv1); - qv2 = vec_madd(qv6, vs5, qv2); - qv3 = vec_madd(qv7, vs7, qv3); - - vsumi0 = vec_add(vec_unpackh(qv0), vsumi0); - vsumi1 = vec_add(vec_unpackh(qv1), vsumi1); - vsumi2 = vec_add(vec_unpackh(qv2), vsumi2); - vsumi3 = vec_add(vec_unpackh(qv3), vsumi3); - - vsumi4 = vec_add(vec_unpackl(qv0), vsumi4); - vsumi5 = vec_add(vec_unpackl(qv1), vsumi5); - vsumi6 = vec_add(vec_unpackl(qv2), vsumi6); - vsumi7 = vec_add(vec_unpackl(qv3), vsumi7); + vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); + vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); + vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); + vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); + vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); + vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); + vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); } vsumi0 = vec_add(vsumi0, vsumi4); @@ -6088,6 +5941,7 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; + const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); @@ -6640,6 +6494,9 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowMask1 = vec_splats((int8_t)0xf); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); const vector signed char v1 = vec_splats((signed char)0x1); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v3 = vec_splats((unsigned char)0x3); @@ -6657,30 +6514,33 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - uint32_t aux[3]; - uint32_t utmp[4]; + UNUSED(kmask1); + UNUSED(kmask2); - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(u0, lowMask1); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); + vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); + vector signed char u31 = vec_and(u3, lowMask2); + + u1 = vec_or(u1, u30); + u2 = vec_or(vec_sr(u0, v4), u31); - vector signed char vscales = (vector signed char)vec_xl( 0, utmp); + vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); vscales = vec_sub(vscales, off); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); - + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; const uint8_t * restrict q3 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -6754,23 +6614,14 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); - vector signed int vsum0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); - vector signed int vsum1 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); - vector signed int vsum2 = vec_add(vec_mule(qv02, vs4), vec_mulo(qv02, vs4)); - vector signed int vsum3 = vec_add(vec_mule(qv03, vs6), vec_mulo(qv03, vs6)); - vector signed int vsum4 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); - vector signed int vsum5 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); - vector signed int vsum6 = vec_add(vec_mule(qv12, vs5), vec_mulo(qv12, vs5)); - vector signed int vsum7 = vec_add(vec_mule(qv13, vs7), vec_mulo(qv13, vs7)); - - vsumi0 = vec_add(vsum0, vsumi0); - vsumi1 = vec_add(vsum1, vsumi1); - vsumi2 = vec_add(vsum2, vsumi2); - vsumi3 = vec_add(vsum3, vsumi3); - vsumi4 = vec_add(vsum4, vsumi4); - vsumi5 = vec_add(vsum5, vsumi5); - vsumi6 = vec_add(vsum6, vsumi6); - vsumi7 = vec_add(vsum7, vsumi7); + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs2, vsumi1); + vsumi2 = vec_msum(qv02, vs4, vsumi2); + vsumi3 = vec_msum(qv03, vs6, vsumi3); + vsumi4 = vec_msum(qv10, vs1, vsumi4); + vsumi5 = vec_msum(qv11, vs3, vsumi5); + vsumi6 = vec_msum(qv12, vs5, vsumi6); + vsumi7 = vec_msum(qv13, vs7, vsumi7); } vsumi0 = vec_add(vsumi0, vsumi4); @@ -6807,6 +6658,8 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void for (int i = 0; i < nb; ++i) { const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; // Set up scales memcpy(aux, x[i].scales, 12); __m128i scales128 = lsx_set_w( @@ -6828,29 +6681,32 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void int bit = 0; int is = 0; + __m256i xvbit; - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; for (int j = 0; j < QK_K/128; ++j) { // load low 2 bits const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; + xvbit = __lasx_xvreplgr2vr_h(bit); // prepare low and high bits const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); - const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2); + const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); ++bit; + xvbit = __lasx_xvreplgr2vr_h(bit); const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); - const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2); + const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); ++bit; + xvbit = __lasx_xvreplgr2vr_h(bit); const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); - const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2); + const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); ++bit; + xvbit = __lasx_xvreplgr2vr_h(bit); const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); - const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2); + const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); ++bit; // load Q8 quants @@ -7264,6 +7120,10 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((uint8_t)2); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); @@ -7282,15 +7142,24 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - memcpy(utmp, x[i].scales, 12); + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - vector signed char utmps = (vector signed char)vec_xl( 0, utmp); vector signed short vscales = vec_unpackh(utmps); vector signed short q4xmins = vec_unpackl(utmps); vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); @@ -7306,14 +7175,10 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -7328,14 +7193,14 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed char qxs3 = (vector signed char)vec_xl(48, q4); q4 += 64; - vector signed char q4x00 = vec_and(qxs0, lowMask); - vector signed char q4x01 = vec_sr(qxs0, v4); - vector signed char q4x10 = vec_and(qxs1, lowMask); - vector signed char q4x11 = vec_sr(qxs1, v4); - vector signed char q4x20 = vec_and(qxs2, lowMask); - vector signed char q4x21 = vec_sr(qxs2, v4); - vector signed char q4x30 = vec_and(qxs3, lowMask); - vector signed char q4x31 = vec_sr(qxs3, v4); + vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); + vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); + vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); + vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); + vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); + vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); vector signed char q8y00 = vec_xl( 0, q8); vector signed char q8y10 = vec_xl( 16, q8); @@ -7347,41 +7212,33 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed char q8y31 = vec_xl(112, q8); q8 += 128; - vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01)); - vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11)); - vector signed short qv20 = vec_add(vec_mule(q4x20, q8y20), vec_mulo(q4x20, q8y20)); - vector signed short qv21 = vec_add(vec_mule(q4x21, q8y21), vec_mulo(q4x21, q8y21)); - vector signed short qv30 = vec_add(vec_mule(q4x30, q8y30), vec_mulo(q4x30, q8y30)); - vector signed short qv31 = vec_add(vec_mule(q4x31, q8y31), vec_mulo(q4x31, q8y31)); - - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); - vector signed short vs2 = vec_splat(vscales, 2); - vector signed short vs3 = vec_splat(vscales, 3); + vector signed int qv00 = vec_msum(q8y00, q4x00, v0); + vector signed int qv01 = vec_msum(q8y01, q4x01, v0); + vector signed int qv10 = vec_msum(q8y10, q4x10, v0); + vector signed int qv11 = vec_msum(q8y11, q4x11, v0); + vector signed int qv20 = vec_msum(q8y20, q4x20, v0); + vector signed int qv21 = vec_msum(q8y21, q4x21, v0); + vector signed int qv30 = vec_msum(q8y30, q4x30, v0); + vector signed int qv31 = vec_msum(q8y31, q4x31, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vector signed int vs2 = vec_splat(vscales_h, 2); + vector signed int vs3 = vec_splat(vscales_h, 3); vscales = vec_sld(vscales, vscales, 8); - qv00 = vec_add(qv00, qv10); - qv10 = vec_add(qv01, qv11); - qv20 = vec_add(qv20, qv30); - qv30 = vec_add(qv21, qv31); + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); - vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); - vsumi2 = vec_add(vec_mule(qv10, vs1), vsumi2); - vsumi3 = vec_add(vec_mulo(qv10, vs1), vsumi3); - vsumi4 = vec_add(vec_mule(qv20, vs2), vsumi4); - vsumi5 = vec_add(vec_mulo(qv20, vs2), vsumi5); - vsumi6 = vec_add(vec_mule(qv30, vs3), vsumi6); - vsumi7 = vec_add(vec_mulo(qv30, vs3), vsumi7); + vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -7399,6 +7256,9 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void *s = vec_extract(vsumf0, 0); #elif defined __loongarch_asx + LM_GGML_UNUSED(kmask1); + LM_GGML_UNUSED(kmask2); + LM_GGML_UNUSED(kmask3); const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); @@ -7411,6 +7271,11 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -7450,16 +7315,17 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void __m256 vd = __lasx_xvreplfr2vr_s(d); acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + } acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); + ft_union fi; fi.i = __lsx_vpickve2gr_w(acc_m, 0); *s = hsum_float_8(acc) + fi.f ; - #else const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -7874,6 +7740,9 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v1 = vec_splats((unsigned char)0x1); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v3 = vec_splats((unsigned char)0x3); @@ -7892,18 +7761,27 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); vector float vdmin = vec_mul(vxmin, vyd); - memcpy(utmp, x[i].scales, 12); + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - vector signed char utmps = (vector signed char)vec_xl( 0, utmp); vector signed short vscales = vec_unpackh(utmps); vector signed short q5xmins = vec_unpackl(utmps); @@ -7923,10 +7801,10 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -7951,10 +7829,10 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void qxhs0 = vec_sr(qxhs0, v2); qxhs1 = vec_sr(qxhs1, v2); - vector signed char q5x00 = vec_or(q5h00, qxs00); - vector signed char q5x01 = vec_or(q5h01, qxs01); - vector signed char q5x10 = vec_or(q5h10, qxs10); - vector signed char q5x11 = vec_or(q5h11, qxs11); + vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); + vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); + vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); + vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); vector signed char q8y00 = vec_xl( 0, q8); vector signed char q8y10 = vec_xl(16, q8); @@ -7962,22 +7840,20 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed char q8y11 = vec_xl(48, q8); q8 += 64; - vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01)); - vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11)); + vector signed int qv00 = vec_msum(q8y00, q5x00, v0); + vector signed int qv01 = vec_msum(q8y01, q5x01, v0); + vector signed int qv10 = vec_msum(q8y10, q5x10, v0); + vector signed int qv11 = vec_msum(q8y11, q5x11, v0); - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); vscales = vec_sld(vscales, vscales, 12); - qv00 = vec_add(qv00, qv10); - qv01 = vec_add(qv01, qv11); - - vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); - vsumi2 = vec_add(vec_mule(qv01, vs1), vsumi2); - vsumi3 = vec_add(vec_mulo(qv01, vs1), vsumi3); + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); + vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); } vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); @@ -7997,6 +7873,9 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void *s = vec_extract(vsumf0, 0); #elif defined __loongarch_asx + LM_GGML_UNUSED(kmask1); + LM_GGML_UNUSED(kmask2); + LM_GGML_UNUSED(kmask3); const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); const __m128i mzero = __lsx_vldi(0); @@ -8015,6 +7894,11 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); @@ -8033,6 +7917,7 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void __m256i sumi = __lasx_xvldi(0); int bit = 0; + __m256i xvbit; for (int j = 0; j < QK_K/64; ++j) { @@ -8041,13 +7926,15 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; + xvbit = __lasx_xvreplgr2vr_h(bit++); const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); - const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4); + const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); hmask = __lasx_xvslli_h(hmask, 1); + xvbit = __lasx_xvreplgr2vr_h(bit++); const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); - const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4); + const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); hmask = __lasx_xvslli_h(hmask, 1); @@ -8061,10 +7948,12 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void p16_1 = lasx_madd_h(scale_1, p16_1); sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + } __m256 vd = __lasx_xvreplfr2vr_s(d); acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + } *s = hsum_float_8(acc) + summs; @@ -8525,6 +8414,7 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v3 = vec_splats((unsigned char)0x3); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -8541,14 +8431,14 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; @@ -8628,23 +8518,14 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void vector signed short vs6 = vec_splat(vscales, 6); vector signed short vs7 = vec_splat(vscales, 7); - vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); - vsumi2 = vec_add(vec_mule(qv01, vs4), vsumi2); - vsumi3 = vec_add(vec_mulo(qv01, vs4), vsumi3); - vsumi4 = vec_add(vec_mule(qv10, vs1), vsumi4); - vsumi5 = vec_add(vec_mulo(qv10, vs1), vsumi5); - vsumi6 = vec_add(vec_mule(qv11, vs5), vsumi6); - vsumi7 = vec_add(vec_mulo(qv11, vs5), vsumi7); - - vsumi0 = vec_add(vec_mule(qv20, vs2), vsumi0); - vsumi1 = vec_add(vec_mulo(qv20, vs2), vsumi1); - vsumi2 = vec_add(vec_mule(qv21, vs6), vsumi2); - vsumi3 = vec_add(vec_mulo(qv21, vs6), vsumi3); - vsumi4 = vec_add(vec_mule(qv30, vs3), vsumi4); - vsumi5 = vec_add(vec_mulo(qv30, vs3), vsumi5); - vsumi6 = vec_add(vec_mule(qv31, vs7), vsumi6); - vsumi7 = vec_add(vec_mulo(qv31, vs7), vsumi7); + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs4, vsumi1); + vsumi2 = vec_msum(qv10, vs1, vsumi2); + vsumi3 = vec_msum(qv11, vs5, vsumi3); + vsumi4 = vec_msum(qv20, vs2, vsumi4); + vsumi5 = vec_msum(qv21, vs6, vsumi5); + vsumi6 = vec_msum(qv30, vs3, vsumi6); + vsumi7 = vec_msum(qv31, vs7, vsumi7); } vsumi0 = vec_add(vsumi0, vsumi4); @@ -8791,7 +8672,7 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void #endif } -#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) +#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -8924,7 +8805,63 @@ void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const vo *s = 0.125f * hsum_float_8(accumf); +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -8937,14 +8874,10 @@ void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const vo vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint16_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -8991,21 +8924,12 @@ void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const vo vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -9279,6 +9203,165 @@ void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const voi } *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const __m128i mone = _mm_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); + const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); + const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); + const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); + const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); + const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; + aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); + + const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); + const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); + const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); + const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); + const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); + const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); + + const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); + const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); + const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); + const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); + + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); + const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); + const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); + const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); + + // AVX2 full_signs_1 is full_sign_bits_0 here + // AVX2 full_signs_2 is full_sign_bits_1 here + __m128i signs_0, signs_1; + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); + const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); + const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); + const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); + + __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); + const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); + const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); + const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); + const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + #elif defined(__loongarch_asx) const __m256i mone = __lasx_xvreplgr2vr_b(1); @@ -9397,6 +9480,7 @@ void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const voi *s = 0.125f * hsum_float_8(accumf); #elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -9409,14 +9493,10 @@ void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const voi vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint16_t * restrict q2 = x[i].qs; const uint8_t * restrict sc = x[i].scales; @@ -9464,21 +9544,12 @@ void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const voi vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7); + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -9694,6 +9765,98 @@ void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *s = 0.125f * hsum_float_8(accumf); +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); + const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); + qs += 8; + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 @@ -9701,6 +9864,8 @@ void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -9715,14 +9880,10 @@ void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q2 = x[i].qs; const uint8_t * restrict qh = x[i].qh; @@ -9782,21 +9943,12 @@ void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7); + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -10031,9 +10183,68 @@ void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const vo *s = 0.25f * hsum_float_8(accumf); +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -10044,14 +10255,10 @@ void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const vo vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q3 = x[i].qs; const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); @@ -10096,21 +10303,12 @@ void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const vo vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -10393,6 +10591,112 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi *s = hsum_float_8(accumf); +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); + const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); + const __m128i idx_mask = _mm_set1_epi32(256); + + typedef union { + __m128i vec[4]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); + const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); + const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; + idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); + idx.vec[1] = idx.vec[0]; + idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); + idx.vec[3] = idx.vec[2]; + + idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); + idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); + idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); + idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); + + idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); + idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); + idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); + idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); + + const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); + const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 @@ -10400,6 +10704,8 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -10420,14 +10726,10 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi const uint8_t * restrict sc = x[i].scales; const int8_t * restrict q8 = y[i].qs; - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; for (int j = 0; j < QK_K/32; j += 2) { __builtin_prefetch(q3, 0, 1); @@ -10481,21 +10783,12 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -10641,6 +10934,14 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi } +#if defined(__AVX__) +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} +#endif + #if defined(__AVX2__) static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { const __m256i ax = _mm256_sign_epi8(x, x); @@ -10758,6 +11059,54 @@ void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const vo *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; +#elif defined __AVX__ + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); + qs += 8; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + #elif defined(__POWER9_VECTOR__) const vector unsigned char v0 = vec_splats((unsigned char)0x0); const vector unsigned short vsign = vec_splats((unsigned short)0x8000); @@ -10776,10 +11125,6 @@ void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const vo vector signed int vsumi1 = vec_splats((int32_t)0); vector signed int vsumi2 = vec_splats((int32_t)0); vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); vector signed int vsumi8 = vec_splats((int32_t)0); const uint8_t * restrict q1 = x[i].qs; @@ -10821,14 +11166,10 @@ void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const vo vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); vector signed short vscales = vec_sld(vscales23, vscales01, 8); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); vector signed short q8ysums = vec_xl_len(qs, 8); qs += 4; @@ -10843,11 +11184,6 @@ void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const vo vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -11109,6 +11445,92 @@ void lm_ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const vo *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); +#elif defined __AVX__ + const __m128i mask = _mm_set1_epi16(0x7); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x( + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x( + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + + const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); + const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); + const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); + const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); + + __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); + __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); + __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); + __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); + + scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); + scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); + scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); + scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); + const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); + const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); + const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); + const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); + const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + #else int sum1[2], sum2[2], delta[4]; @@ -11173,6 +11595,9 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi const int nb = n / QK4_NL; + int ib = 0; + float sumf = 0; + #if defined __ARM_NEON const int8x16_t values = vld1q_s8(kvalues_iq4nl); const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -11181,16 +11606,14 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi int8x16x4_t q8b; int32x4_t prod_1, prod_2; - float sumf = 0; - - for (int ib = 0; ib < nb; ib += 2) { + for (; ib + 1 < nb; ib += 2) { - q4bits.val[0] = vld1q_u8(x[ib+0].qs); - q4bits.val[1] = vld1q_u8(x[ib+1].qs); - q8b.val[0] = vld1q_s8(y[ib+0].qs); - q8b.val[1] = vld1q_s8(y[ib+0].qs + 16); - q8b.val[2] = vld1q_s8(y[ib+1].qs); - q8b.val[3] = vld1q_s8(y[ib+1].qs + 16); + q4bits.val[0] = vld1q_u8(x[ib + 0].qs); + q4bits.val[1] = vld1q_u8(x[ib + 1].qs); + q8b.val[0] = vld1q_s8(y[ib + 0].qs); + q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib + 1].qs); + q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); @@ -11201,12 +11624,10 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); sumf += - LM_GGML_FP16_TO_FP32(x[ib+0].d) * LM_GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) + - LM_GGML_FP16_TO_FP32(x[ib+1].d) * LM_GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2); + LM_GGML_FP16_TO_FP32(x[ib+0].d) * LM_GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + + LM_GGML_FP16_TO_FP32(x[ib+1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); } - *s = sumf; - #elif defined __AVX2__ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); @@ -11215,11 +11636,11 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); - for (int ib = 0; ib < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs); - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), @@ -11228,19 +11649,52 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[0].d)*LM_GGML_FP16_TO_FP32(x[0].d)), + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[1].d)*LM_GGML_FP16_TO_FP32(x[1].d)), + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), _mm256_cvtepi32_ps(p_2), accum2); - - y += 2; - x += 2; } - *s = hsum_float_8(_mm256_add_ps(accum1, accum2)); + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), + _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), + _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2); + } + + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); @@ -11249,7 +11703,7 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi const vector signed char values = vec_xl( 0, kvalues_iq4nl); #pragma GCC unroll 4 - for (int ib = 0; ib < nb; ++ib) { + for (; ib < nb; ++ib) { __builtin_prefetch(x[ib].qs, 0, 1); __builtin_prefetch(y[ib].qs, 0, 1); @@ -11271,8 +11725,11 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); - vector signed int vsumi1 = vec_add(vec_unpackh(qv1), vec_unpackl(qv1)); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); @@ -11283,7 +11740,7 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined (__loongarch_asx) @@ -11293,11 +11750,11 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi __m256 accum1 = (__m256)__lasx_xvldi(0); __m256 accum2 = (__m256)__lasx_xvldi(0); - for (int ib = 0; ib < nb; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[0].qs, 0); - const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0); - const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0); - const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); + const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); + const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); + const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), @@ -11306,20 +11763,16 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = lasx_madd_h(p16_1, mone); const __m256i p_2 = lasx_madd_h(p16_2, mone); - accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[0].d)*LM_GGML_FP16_TO_FP32(x[0].d)), + accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), __lasx_xvffint_s_w(p_1), accum1); - accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[1].d)*LM_GGML_FP16_TO_FP32(x[1].d)), + accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), __lasx_xvffint_s_w(p_2), accum2); - - y += 2; - x += 2; } - *s = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); -#else - float sumf = 0; - for (int ib = 0; ib < nb; ++ib) { +#endif + for (; ib < nb; ++ib) { const float d = LM_GGML_FP16_TO_FP32(y[ib].d)*LM_GGML_FP16_TO_FP32(x[ib].d); int sumi1 = 0, sumi2 = 0; for (int j = 0; j < QK4_NL/2; ++j) { @@ -11329,7 +11782,6 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi sumf += d * (sumi1 + sumi2); } *s = sumf; -#endif } void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -11425,8 +11877,57 @@ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const voi *s = hsum_float_8(accum); +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); + sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); + sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); + sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); + sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); + } + __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); + __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); + } + + *s = hsum_float_8(accum); + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); @@ -11442,14 +11943,10 @@ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const voi vector float vyd = vec_splats(y[ibl].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; uint16_t h = x[ibl].scales_h; @@ -11494,21 +11991,12 @@ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const voi vector signed short vscales01 = vec_splats((int16_t)ls0); vector signed short vscales23 = vec_splats((int16_t)ls1); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -12880,10 +13368,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_xxs * restrict y = vy; - quantize_row_iq3_xxs_reference(x, y, k); + quantize_row_iq3_xxs_ref(x, y, k); } -void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { +void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); } @@ -13096,10 +13584,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_s * restrict y = vy; - quantize_row_iq3_s_reference(x, y, k); + quantize_row_iq3_s_ref(x, y, k); } -void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { +void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); } @@ -13111,7 +13599,7 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) { int num_neighbors = neighbours[0]; LM_GGML_ASSERT(num_neighbors > 0); - float best_score = 0; + float best_score = -FLT_MAX; int grid_index = -1; for (int j = 1; j <= num_neighbors; ++j) { const int8_t * pg = (const int8_t *)(grid + neighbours[j]); @@ -13309,7 +13797,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy sumw[j+1] = sumw[j] + weight[i]; } } - float best_score = 0, scale = max; + float best_score = -FLT_MIN, scale = max; int besti1 = -1, besti2 = -1, best_shift = 0; for (int i1 = 0; i1 <= block_size; ++i1) { for (int i2 = i1; i2 <= block_size; ++i2) { @@ -13485,7 +13973,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy idx[2*j] = j; } qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = 0, scale = max; + float best_score = -FLT_MIN, scale = max; int besti1 = -1, besti2 = -1, best_k = -1; // 0: +, + // 1: +, - @@ -13837,7 +14325,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k } } -void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { +void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { assert(k % QK4_NL == 0); quantize_row_iq4_nl(x, y, k); } @@ -13865,10 +14353,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq4_xs * restrict y = vy; - quantize_row_iq4_xs_reference(x, y, k); + quantize_row_iq4_xs_ref(x, y, k); } -void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { +void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); } @@ -14055,7 +14543,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq2_s); } -void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { +void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_s(x, y, 1, k, NULL); } @@ -14063,7 +14551,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq2_s * restrict y = vy; - quantize_row_iq2_s_reference(x, y, k); + quantize_row_iq2_s_ref(x, y, k); } static bool validate_float(float f, size_t i) { @@ -14118,6 +14606,16 @@ static bool validate_fp16(lm_ggml_fp16_t f, size_t i) { } \ } +#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + for (size_t j = 0; j < (nr); ++j) { \ + if (!validate_fp16(q[i].d[j], i)) { \ + return false; \ + } \ + } \ + } + bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t nbytes) { if (type < 0 || type >= LM_GGML_TYPE_COUNT) { fprintf(stderr, "%s: invalid type %d\n", __func__, type); @@ -14335,6 +14833,16 @@ bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + { + VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4); + } break; + case LM_GGML_TYPE_Q4_0_8_8: + { + VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8); + } break; + case LM_GGML_TYPE_I8: case LM_GGML_TYPE_I16: case LM_GGML_TYPE_I32: diff --git a/cpp/ggml-quants.h b/cpp/ggml-quants.h index 792ab268..d3c3461a 100644 --- a/cpp/ggml-quants.h +++ b/cpp/ggml-quants.h @@ -12,25 +12,25 @@ extern "C" { #endif // Quantization -void quantize_row_q4_0_reference(const float * LM_GGML_RESTRICT x, block_q4_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1_reference(const float * LM_GGML_RESTRICT x, block_q4_1 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0_reference(const float * LM_GGML_RESTRICT x, block_q5_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1_reference(const float * LM_GGML_RESTRICT x, block_q5_1 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_reference(const float * LM_GGML_RESTRICT x, block_q8_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1_reference(const float * LM_GGML_RESTRICT x, block_q8_1 * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_q2_K_reference(const float * LM_GGML_RESTRICT x, block_q2_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K_reference(const float * LM_GGML_RESTRICT x, block_q3_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K_reference(const float * LM_GGML_RESTRICT x, block_q4_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K_reference(const float * LM_GGML_RESTRICT x, block_q5_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K_reference(const float * LM_GGML_RESTRICT x, block_q6_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K_reference(const float * LM_GGML_RESTRICT x, block_q8_K * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_iq3_xxs_reference(const float * LM_GGML_RESTRICT x, block_iq3_xxs * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl_reference (const float * LM_GGML_RESTRICT x, block_iq4_nl * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_reference (const float * LM_GGML_RESTRICT x, block_iq4_xs * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s_reference (const float * LM_GGML_RESTRICT x, block_iq3_s * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s_reference (const float * LM_GGML_RESTRICT x, block_iq2_s * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_ref(const float * LM_GGML_RESTRICT x, block_q4_0 * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_ref(const float * LM_GGML_RESTRICT x, block_q4_1 * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_ref(const float * LM_GGML_RESTRICT x, block_q5_0 * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_ref(const float * LM_GGML_RESTRICT x, block_q5_1 * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_ref(const float * LM_GGML_RESTRICT x, block_q8_0 * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_ref(const float * LM_GGML_RESTRICT x, block_q8_1 * LM_GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K_ref(const float * LM_GGML_RESTRICT x, block_q2_K * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_ref(const float * LM_GGML_RESTRICT x, block_q3_K * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_ref(const float * LM_GGML_RESTRICT x, block_q4_K * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_ref(const float * LM_GGML_RESTRICT x, block_q5_K * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_ref(const float * LM_GGML_RESTRICT x, block_q6_K * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_ref(const float * LM_GGML_RESTRICT x, block_q8_K * LM_GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs_ref(const float * LM_GGML_RESTRICT x, block_iq3_xxs * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_ref (const float * LM_GGML_RESTRICT x, block_iq4_nl * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_ref (const float * LM_GGML_RESTRICT x, block_iq4_xs * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_ref (const float * LM_GGML_RESTRICT x, block_iq3_s * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_ref (const float * LM_GGML_RESTRICT x, block_iq2_s * LM_GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); @@ -130,4 +130,3 @@ void iq3xs_free_impl(int grid_size); #ifdef __cplusplus } #endif - diff --git a/cpp/ggml.c b/cpp/ggml.c index d9eb5ba0..acd23144 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -4,6 +4,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" +#include "ggml-aarch64.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -28,16 +29,20 @@ #include #endif +#ifdef LM_GGML_USE_OPENMP +#include +#endif + #ifdef LM_GGML_USE_METAL #include #endif -#ifdef __ARM_FEATURE_MATMUL_INT8 +#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) #undef LM_GGML_USE_LLAMAFILE #endif #ifdef LM_GGML_USE_LLAMAFILE -#include "sgemm.h" +#include #endif #if defined(_MSC_VER) @@ -60,6 +65,9 @@ typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; +typedef atomic_int atomic_flag; + +#define ATOMIC_FLAG_INIT 0 static void atomic_store(atomic_int * ptr, LONG val) { InterlockedExchange(ptr, val); @@ -73,6 +81,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { return atomic_fetch_add(ptr, -(dec)); } +static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { + return InterlockedExchange(ptr, 1); +} +static void atomic_flag_clear(atomic_flag * ptr) { + InterlockedExchange(ptr, 0); +} typedef HANDLE pthread_t; @@ -161,7 +175,6 @@ void lm_ggml_print_backtrace(void) { } #endif -/*#define LM_GGML_PERF*/ #define LM_GGML_DEBUG 0 #define LM_GGML_GELU_FP16 #define LM_GGML_GELU_QUICK_FP16 @@ -279,21 +292,10 @@ inline static void * lm_ggml_calloc(size_t num, size_t size) { #define LM_GGML_FREE(ptr) free(ptr) #define UNUSED LM_GGML_UNUSED -#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) +#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0) #if defined(LM_GGML_USE_ACCELERATE) #include -#if defined(LM_GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions -#include "ggml-opencl.h" -#endif -#elif defined(LM_GGML_USE_OPENBLAS) -#if defined(LM_GGML_BLAS_USE_MKL) -#include -#else -#include -#endif -#elif defined(LM_GGML_USE_CLBLAST) -#include "ggml-opencl.h" #endif // floating point type used to accumulate sums @@ -471,18 +473,6 @@ int64_t lm_ggml_cycles_per_ms(void) { return CLOCKS_PER_SEC/1000; } -#ifdef LM_GGML_PERF -#define lm_ggml_perf_time_ms() lm_ggml_time_ms() -#define lm_ggml_perf_time_us() lm_ggml_time_us() -#define lm_ggml_perf_cycles() lm_ggml_cycles() -#define lm_ggml_perf_cycles_per_ms() lm_ggml_cycles_per_ms() -#else -#define lm_ggml_perf_time_ms() 0 -#define lm_ggml_perf_time_us() 0 -#define lm_ggml_perf_cycles() 0 -#define lm_ggml_perf_cycles_per_ms() 0 -#endif - // // cross-platform UTF-8 file paths // @@ -602,7 +592,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (lm_ggml_to_float_t) lm_ggml_fp16_to_fp32_row, .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row, - .from_float_reference = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row, + .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row, .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_f16, .vec_dot_type = LM_GGML_TYPE_F16, .nrows = 1, @@ -614,7 +604,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q4_0, .from_float = quantize_row_q4_0, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q4_0_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_0_ref, .vec_dot = lm_ggml_vec_dot_q4_0_q8_0, .vec_dot_type = LM_GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -630,7 +620,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q4_1, .from_float = quantize_row_q4_1, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q4_1_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = lm_ggml_vec_dot_q4_1_q8_1, .vec_dot_type = LM_GGML_TYPE_Q8_1, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -646,7 +636,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = LM_GGML_TYPE_COUNT, .nrows = 1, @@ -658,7 +648,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = LM_GGML_TYPE_COUNT, .nrows = 1, @@ -670,7 +660,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q5_0, .from_float = quantize_row_q5_0, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q5_0_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_0_ref, .vec_dot = lm_ggml_vec_dot_q5_0_q8_0, .vec_dot_type = LM_GGML_TYPE_Q8_0, .nrows = 1, @@ -682,7 +672,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q5_1, .from_float = quantize_row_q5_1, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q5_1_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = lm_ggml_vec_dot_q5_1_q8_1, .vec_dot_type = LM_GGML_TYPE_Q8_1, .nrows = 1, @@ -694,7 +684,8 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q8_0, .from_float = quantize_row_q8_0, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q8_0_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q8_0_ref, + .from_float_to_mat = quantize_mat_q8_0, .vec_dot = lm_ggml_vec_dot_q8_0_q8_0, .vec_dot_type = LM_GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -709,7 +700,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_1), .is_quantized = true, .from_float = quantize_row_q8_1, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q8_1_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q8_1_ref, .vec_dot_type = LM_GGML_TYPE_Q8_1, .nrows = 1, }, @@ -720,7 +711,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q2_K, .from_float = quantize_row_q2_K, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q2_K_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q2_K_ref, .vec_dot = lm_ggml_vec_dot_q2_K_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -732,7 +723,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q3_K, .from_float = quantize_row_q3_K, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q3_K_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q3_K_ref, .vec_dot = lm_ggml_vec_dot_q3_K_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -744,7 +735,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q4_K, .from_float = quantize_row_q4_K, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q4_K_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_K_ref, .vec_dot = lm_ggml_vec_dot_q4_K_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -756,7 +747,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q5_K, .from_float = quantize_row_q5_K, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q5_K_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_K_ref, .vec_dot = lm_ggml_vec_dot_q5_K_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -768,7 +759,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q6_K, .from_float = quantize_row_q6_K, - .from_float_reference = (lm_ggml_from_float_t) quantize_row_q6_K_reference, + .from_float_ref = (lm_ggml_from_float_t) quantize_row_q6_K_ref, .vec_dot = lm_ggml_vec_dot_q6_K_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -780,7 +771,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_xxs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = lm_ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -792,7 +783,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_xs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = lm_ggml_vec_dot_iq2_xs_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -804,7 +795,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq3_xxs, .from_float = quantize_row_iq3_xxs, - .from_float_reference = (lm_ggml_from_float_t)quantize_row_iq3_xxs_reference, + .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq3_xxs_ref, .vec_dot = lm_ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -816,7 +807,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq3_s, .from_float = quantize_row_iq3_s, - .from_float_reference = (lm_ggml_from_float_t)quantize_row_iq3_s_reference, + .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq3_s_ref, .vec_dot = lm_ggml_vec_dot_iq3_s_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -828,7 +819,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_s, .from_float = quantize_row_iq2_s, - .from_float_reference = (lm_ggml_from_float_t)quantize_row_iq2_s_reference, + .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq2_s_ref, .vec_dot = lm_ggml_vec_dot_iq2_s_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -840,7 +831,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq1_s, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = lm_ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -852,7 +843,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq1_m, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = lm_ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -864,7 +855,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq4_nl, .from_float = quantize_row_iq4_nl, - .from_float_reference = (lm_ggml_from_float_t)quantize_row_iq4_nl_reference, + .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = lm_ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = LM_GGML_TYPE_Q8_0, .nrows = 1, @@ -876,7 +867,7 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq4_xs, .from_float = quantize_row_iq4_xs, - .from_float_reference = (lm_ggml_from_float_t)quantize_row_iq4_xs_reference, + .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq4_xs_ref, .vec_dot = lm_ggml_vec_dot_iq4_xs_q8_K, .vec_dot_type = LM_GGML_TYPE_Q8_K, .nrows = 1, @@ -895,10 +886,58 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (lm_ggml_to_float_t) lm_ggml_bf16_to_fp32_row, .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row, - .from_float_reference = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row, + .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row, .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_bf16, .vec_dot_type = LM_GGML_TYPE_BF16, .nrows = 1, + }, + [LM_GGML_TYPE_Q4_0_4_4] = { + .type_name = "q4_0_4x4", + .blck_size = QK4_0, + .blck_size_interleave = 4, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .gemv = lm_ggml_gemv_q4_0_4x4_q8_0, + .gemm = lm_ggml_gemm_q4_0_4x4_q8_0, + }, + [LM_GGML_TYPE_Q4_0_4_8] = { + .type_name = "q4_0_4x8", + .blck_size = QK4_0, + .blck_size_interleave = 8, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .gemv = lm_ggml_gemv_q4_0_4x8_q8_0, + .gemm = lm_ggml_gemm_q4_0_4x8_q8_0, + }, + [LM_GGML_TYPE_Q4_0_8_8] = { + .type_name = "q4_0_8x8", + .blck_size = QK4_0, + .blck_size_interleave = 8, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 8, + .gemv = lm_ggml_gemv_q4_0_8x8_q8_0, + .gemm = lm_ggml_gemm_q4_0_8x8_q8_0, } }; @@ -1567,11 +1606,11 @@ do { \ // F16 arithmetic is not supported by AVX, so we use F32 instead -#define LM_GGML_F32Cx8 __m256 +#define LM_GGML_F32Cx8 __m256 #define LM_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) #define LM_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) -static inline __m256 __lasx_f32cx8_load(lm_ggml_fp16_t *x) { +static inline __m256 __lasx_f32cx8_load(const lm_ggml_fp16_t * x) { float tmp[8]; for (int i = 0; i < 8; i++) { @@ -1580,13 +1619,14 @@ static inline __m256 __lasx_f32cx8_load(lm_ggml_fp16_t *x) { return (__m256)__lasx_xvld(tmp, 0); } -static inline void __lasx_f32cx8_store(lm_ggml_fp16_t *x, __m256 y) { +static inline void __lasx_f32cx8_store(lm_ggml_fp16_t * x, __m256 y) { float arr[8]; __lasx_xvst(y, arr, 0); - for (int i = 0; i < 8; i++) + for (int i = 0; i < 8; i++) { x[i] = LM_GGML_FP32_TO_FP16(arr[i]); + } } #define LM_GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) #define LM_GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) @@ -1662,7 +1702,7 @@ static inline void __lasx_f32cx8_store(lm_ggml_fp16_t *x, __m256 y) { #define LM_GGML_F16_STEP 32 #define LM_GGML_F16_EPR 4 -static inline __m128 __lsx_f16x4_load(lm_ggml_fp16_t *x) { +static inline __m128 __lsx_f16x4_load(const lm_ggml_fp16_t * x) { float tmp[4]; tmp[0] = LM_GGML_FP16_TO_FP32(x[0]); @@ -1673,7 +1713,7 @@ static inline __m128 __lsx_f16x4_load(lm_ggml_fp16_t *x) { return __lsx_vld(tmp, 0); } -static inline void __lsx_f16x4_store(lm_ggml_fp16_t *x, __m128 y) { +static inline void __lsx_f16x4_store(lm_ggml_fp16_t * x, __m128 y) { float arr[4]; __lsx_vst(y, arr, 0); @@ -1726,8 +1766,8 @@ struct lm_ggml_context { int n_objects; - struct lm_ggml_object* objects_begin; - struct lm_ggml_object* objects_end; + struct lm_ggml_object * objects_begin; + struct lm_ggml_object * objects_end; struct lm_ggml_scratch scratch; struct lm_ggml_scratch scratch_save; @@ -1740,30 +1780,38 @@ struct lm_ggml_context_container { }; struct lm_ggml_compute_state_shared { - const struct lm_ggml_cgraph* cgraph; - const struct lm_ggml_cplan* cplan; + const struct lm_ggml_cgraph * cgraph; + const struct lm_ggml_cplan * cplan; - int64_t perf_node_start_cycles; - int64_t perf_node_start_time_us; - - const int n_threads; + int n_threads; // synchronization primitives - atomic_int n_active; // num active threads - atomic_int node_n; // active graph node - atomic_int node_task; // active graph node task phase + atomic_int n_barrier; + atomic_int n_barrier_passed; lm_ggml_abort_callback abort_callback; // abort lm_ggml_graph_compute when true - void* abort_callback_data; + void * abort_callback_data; + + atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + enum lm_ggml_status ec; }; struct lm_ggml_compute_state { lm_ggml_thread_t thrd; int ith; - struct lm_ggml_compute_state_shared* shared; - enum lm_ggml_status ec; + struct lm_ggml_compute_state_shared * shared; +}; + +struct lm_ggml_compute_params { + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + + struct lm_ggml_compute_state_shared * shared; }; // @@ -2257,6 +2305,11 @@ inline static float lm_ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); } +#if __FINITE_MATH_ONLY__ +#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" +#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" +#endif + #if defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine @@ -2306,32 +2359,27 @@ inline static __m512 lm_ggml_v_expf(__m512 x) { const __m512 r = _mm512_set1_ps(0x1.8p23f); const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); const __m512 n = _mm512_sub_ps(z, r); - const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); - const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23); - const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1)))); - const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ); - const __m512 u = _mm512_mul_ps(b, b); - const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, - _mm512_set1_ps(0x1.573e2ep-5f)), u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, - _mm512_set1_ps(0x1.fffdb6p-2f))), - u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b)); - if (_mm512_kortestz(c, c)) - return _mm512_fmadd_ps(j, k, k); - const __m512i g = _mm512_and_si512( - _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)), - _mm512_set1_epi32(0x82000000u)); - const __m512 s1 = - _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u))); - const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g)); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); const __mmask16 d = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); - return _mm512_mask_blend_ps( - d, _mm512_mask_blend_ps( - c, _mm512_fmadd_ps(k, j, k), - _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)), - _mm512_mul_ps(s1, s1)); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); } // computes silu x/(1+exp(-x)) in single precision vector @@ -2811,42 +2859,6 @@ static_assert(LM_GGML_UNARY_OP_COUNT == 13, "LM_GGML_UNARY_OP_COUNT != 13"); static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN"); static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN"); -// WARN: -// Mis-configuration can lead to problem that's hard to reason about: -// * At best it crash or talks nosense. -// * At worst it talks slightly difference but hard to perceive. -// -// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. -// Take care about compile options (e.g., LM_GGML_USE_xxx). -static bool LM_GGML_OP_HAS_INIT [LM_GGML_OP_COUNT] = { 0 }; -static bool LM_GGML_OP_HAS_FINALIZE[LM_GGML_OP_COUNT] = { 0 }; - -static void lm_ggml_setup_op_has_task_pass(void) { - { // INIT - bool * p = LM_GGML_OP_HAS_INIT; - - p[LM_GGML_OP_ACC ] = true; - p[LM_GGML_OP_MUL_MAT ] = true; - p[LM_GGML_OP_MUL_MAT_ID ] = true; - p[LM_GGML_OP_OUT_PROD ] = true; - p[LM_GGML_OP_SET ] = true; - p[LM_GGML_OP_GET_ROWS_BACK ] = true; - p[LM_GGML_OP_DIAG_MASK_INF ] = true; - p[LM_GGML_OP_DIAG_MASK_ZERO ] = true; - p[LM_GGML_OP_CONV_TRANSPOSE_1D ] = true; - p[LM_GGML_OP_CONV_TRANSPOSE_2D ] = true; - p[LM_GGML_OP_FLASH_ATTN_BACK ] = true; - p[LM_GGML_OP_CROSS_ENTROPY_LOSS ] = true; - p[LM_GGML_OP_ADD_REL_POS ] = true; - } - - { // FINALIZE - bool * p = LM_GGML_OP_HAS_FINALIZE; - - p[LM_GGML_OP_CROSS_ENTROPY_LOSS ] = true; - } -} - // // NUMA support // @@ -2883,24 +2895,62 @@ struct lm_ggml_state { // global state static struct lm_ggml_state g_state; -static atomic_int g_state_barrier = 0; +static atomic_flag g_state_critical = ATOMIC_FLAG_INIT; -// barrier via spin lock +// critical section via spin lock inline static void lm_ggml_critical_section_start(void) { - int processing = atomic_fetch_add(&g_state_barrier, 1); + while (atomic_flag_test_and_set(&g_state_critical)) { + // spin + sched_yield(); + } +} + +#ifdef LM_GGML_USE_OPENMP +static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) { + if (shared->n_threads == 1) { + return; + } + + #pragma omp barrier +} +#else +static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) { + if (shared->n_threads == 1) { + return; + } + + atomic_int * n_barrier = &shared->n_barrier; + atomic_int * n_barrier_passed = &shared->n_barrier_passed; - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); // TODO: reconsider this - processing = atomic_fetch_add(&g_state_barrier, 1); + int n_threads = shared->n_threads; + int passed_old = atomic_load(n_barrier_passed); + + if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) { + // last thread + atomic_store(n_barrier, 0); + atomic_fetch_add(n_barrier_passed, 1); + } else { + // wait for other threads + const int n_spin_before_sleep = 100000; + while (true) { + for (int i = 0; i < n_spin_before_sleep; i++) { + if (atomic_load(n_barrier_passed) != passed_old) { + return; + } + #if defined(__SSE3__) + _mm_pause(); + #endif + } + sched_yield(); + } } } +#endif // TODO: make this somehow automatically executed // some sort of "sentry" mechanism inline static void lm_ggml_critical_section_end(void) { - atomic_fetch_sub(&g_state_barrier, 1); + atomic_flag_clear(&g_state_critical); } #if defined(__gnu_linux__) @@ -3001,7 +3051,7 @@ void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa_flag) { } } #else - LM_GGML_UNUSED(numa_flag); + UNUSED(numa_flag); // TODO #endif } @@ -3065,7 +3115,7 @@ size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor) { return LM_GGML_PAD(lm_ggml_nbytes(tensor), LM_GGML_MEM_ALIGN); } -LM_GGML_CALL int lm_ggml_blck_size(enum lm_ggml_type type) { +LM_GGML_CALL int64_t lm_ggml_blck_size(enum lm_ggml_type type) { return type_traits[type].blck_size; } @@ -3107,9 +3157,7 @@ LM_GGML_CALL const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) { enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(t); return lm_ggml_unary_op_name(uop); } - else { - return lm_ggml_op_name(t->op); - } + return lm_ggml_op_name(t->op); } LM_GGML_CALL size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) { @@ -3189,6 +3237,9 @@ enum lm_ggml_type lm_ggml_ftype_to_lm_ggml_type(enum lm_ggml_ftype ftype) { case LM_GGML_FTYPE_MOSTLY_IQ4_XS: wtype = LM_GGML_TYPE_IQ4_XS; break; case LM_GGML_FTYPE_MOSTLY_IQ3_S: wtype = LM_GGML_TYPE_IQ3_S; break; case LM_GGML_FTYPE_MOSTLY_IQ2_S: wtype = LM_GGML_TYPE_IQ2_S; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = LM_GGML_TYPE_Q4_0_4_4; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = LM_GGML_TYPE_Q4_0_4_8; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = LM_GGML_TYPE_Q4_0_8_8; break; case LM_GGML_FTYPE_UNKNOWN: wtype = LM_GGML_TYPE_COUNT; break; case LM_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = LM_GGML_TYPE_COUNT; break; } @@ -3206,23 +3257,42 @@ LM_GGML_CALL bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } +static bool lm_ggml_is_contiguous_n(const struct lm_ggml_tensor * tensor, int n) { + size_t next_nb = lm_ggml_type_size(tensor->type); + if (tensor->ne[0] != lm_ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { + return false; + } + next_nb *= tensor->ne[0]/lm_ggml_blck_size(tensor->type); + for (int i = 1; i < LM_GGML_MAX_DIMS; i++) { + if (tensor->ne[i] != 1) { + if (i > n) { + if (tensor->nb[i] != next_nb) { + return false; + } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; + } + } + } + return true; +} + LM_GGML_CALL bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + return lm_ggml_is_contiguous_0(tensor); +} - return - tensor->nb[0] == lm_ggml_type_size(tensor->type) && - tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/lm_ggml_blck_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +LM_GGML_CALL bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_n(tensor, 0); } -static inline bool lm_ggml_is_contiguous_except_dim_1(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); +LM_GGML_CALL bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_n(tensor, 1); +} - return - tensor->nb[0] == lm_ggml_type_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +LM_GGML_CALL bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_n(tensor, 2); } LM_GGML_CALL bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) { @@ -3254,24 +3324,24 @@ bool lm_ggml_are_same_shape(const struct lm_ggml_tensor * t0, const struct lm_gg static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return - (t0->ne[0] == t1->ne[0] ) && - (t0->ne[1] == t1->ne[1] ) && - (t0->ne[2] == t1->ne[2] ) && - (t0->ne[3] == t1->ne[3] ); + (t0->ne[0] == t1->ne[0]) && + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); } bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return - (t0->nb[0] == t1->nb[0] ) && - (t0->nb[1] == t1->nb[1] ) && - (t0->nb[2] == t1->nb[2] ) && - (t0->nb[3] == t1->nb[3] ); + (t0->nb[0] == t1->nb[0]) && + (t0->nb[1] == t1->nb[1]) && + (t0->nb[2] == t1->nb[2]) && + (t0->nb[3] == t1->nb[3]); } // check if t1 can be represented as a repeatition of t0 -static inline bool lm_ggml_can_repeat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { +bool lm_ggml_can_repeat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return lm_ggml_is_empty(t0) ? lm_ggml_is_empty(t1) : @@ -3357,12 +3427,6 @@ struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) { LM_GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } -#if defined(LM_GGML_USE_CLBLAST) - lm_ggml_cl_init(); -#endif - - lm_ggml_setup_op_has_task_pass(); - is_first_call = false; } @@ -3629,15 +3693,12 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( /*.flags =*/ 0, /*.grad =*/ NULL, /*.src =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, /*.view_src =*/ view_src, /*.view_offs =*/ view_offs, /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - /*.padding =*/ { 0 }, + ///*.padding =*/ { 0 }, }; #ifdef __clang__ @@ -4064,32 +4125,26 @@ float lm_ggml_get_f32_1d(const struct lm_ggml_tensor * tensor, int i) { switch (tensor->type) { case LM_GGML_TYPE_I8: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; } case LM_GGML_TYPE_I16: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; } case LM_GGML_TYPE_I32: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; } case LM_GGML_TYPE_F16: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *)(tensor->data))[i]); } case LM_GGML_TYPE_BF16: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *)(tensor->data))[i]); } case LM_GGML_TYPE_F32: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; } default: @@ -4111,32 +4166,26 @@ void lm_ggml_set_f32_1d(const struct lm_ggml_tensor * tensor, int i, float value switch (tensor->type) { case LM_GGML_TYPE_I8: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); ((int8_t *)(tensor->data))[i] = value; } break; case LM_GGML_TYPE_I16: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); ((int16_t *)(tensor->data))[i] = value; } break; case LM_GGML_TYPE_I32: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); ((int32_t *)(tensor->data))[i] = value; } break; case LM_GGML_TYPE_F16: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); ((lm_ggml_fp16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_FP16(value); } break; case LM_GGML_TYPE_BF16: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); ((lm_ggml_bf16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_BF16(value); } break; case LM_GGML_TYPE_F32: { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(float)); ((float *)(tensor->data))[i] = value; } break; default: @@ -5315,7 +5364,7 @@ void lm_ggml_mul_mat_set_prec( as -> [cols, rows, n_expert] ids -> [n_experts_used, n_tokens] (i32) b -> [cols, n_expert_used, n_tokens] - c -> [cols, n_expert_used, n_tokens] + c -> [rows, n_expert_used, n_tokens] in b, n_experts_used can be broadcasted to match the n_expert_used of ids @@ -6236,16 +6285,13 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl( struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, - float xpos_base, - bool xpos_down, bool inplace) { LM_GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); @@ -6266,15 +6312,13 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl( struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); lm_ggml_set_op_params(result, params, sizeof(params)); result->op = LM_GGML_OP_ROPE; @@ -6291,10 +6335,9 @@ struct lm_ggml_tensor * lm_ggml_rope( struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false ); } @@ -6303,10 +6346,9 @@ struct lm_ggml_tensor * lm_ggml_rope_inplace( struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true ); } @@ -6317,8 +6359,7 @@ struct lm_ggml_tensor * lm_ggml_rope_ext( struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6326,8 +6367,8 @@ struct lm_ggml_tensor * lm_ggml_rope_ext( float beta_fast, float beta_slow) { return lm_ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6338,8 +6379,7 @@ struct lm_ggml_tensor * lm_ggml_rope_ext_inplace( struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6347,8 +6387,8 @@ struct lm_ggml_tensor * lm_ggml_rope_ext_inplace( float beta_fast, float beta_slow) { return lm_ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -6358,8 +6398,7 @@ struct lm_ggml_tensor * lm_ggml_rope_custom( struct lm_ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6367,8 +6406,8 @@ struct lm_ggml_tensor * lm_ggml_rope_custom( float beta_fast, float beta_slow) { return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6378,8 +6417,7 @@ struct lm_ggml_tensor * lm_ggml_rope_custom_inplace( struct lm_ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6387,8 +6425,8 @@ struct lm_ggml_tensor * lm_ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -6401,16 +6439,13 @@ struct lm_ggml_tensor * lm_ggml_rope_back( struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down) { + float beta_slow) { LM_GGML_ASSERT(lm_ggml_is_vector(b)); LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); LM_GGML_ASSERT(a->ne[2] == b->ne[0]); @@ -6426,15 +6461,13 @@ struct lm_ggml_tensor * lm_ggml_rope_back( struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); lm_ggml_set_op_params(result, params, sizeof(params)); result->op = LM_GGML_OP_ROPE_BACK; @@ -7345,13 +7378,15 @@ struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace( return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } -// gmml_unary +// lm_ggml_unary static struct lm_ggml_tensor * lm_ggml_unary_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, enum lm_ggml_unary_op op, bool inplace) { + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(a)); + bool is_node = false; if (!inplace && (a->grad)) { @@ -7841,10 +7876,6 @@ static void lm_ggml_compute_forward_dup_same_cont( LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); LM_GGML_ASSERT(src0->type == dst->type); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const size_t nb00 = src0->nb[0]; const size_t nb0 = dst->nb[0]; @@ -7873,10 +7904,6 @@ static void lm_ggml_compute_forward_dup_f16( LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index @@ -8146,10 +8173,6 @@ static void lm_ggml_compute_forward_dup_bf16( LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index @@ -8506,10 +8529,6 @@ static void lm_ggml_compute_forward_dup_f32( LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index @@ -8829,10 +8848,6 @@ static void lm_ggml_compute_forward_dup_bytes( LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); LM_GGML_ASSERT(src0->type == dst->type); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst)) { lm_ggml_compute_forward_dup_same_cont(params, dst); return; @@ -9013,24 +9028,9 @@ static void lm_ggml_compute_forward_add_f32( LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; -#ifdef LM_GGML_USE_CLBLAST - if (src1->backend == LM_GGML_BACKEND_TYPE_GPU) { - // TODO: OpenCL kernel support full broadcast - LM_GGML_ASSERT(lm_ggml_can_repeat_rows(src1, src0)); - if (ith == 0) { - lm_ggml_cl_add(src0, src1, dst); - } - return; - } -#endif - const int nr = lm_ggml_nrows(src0); LM_GGML_TENSOR_BINARY_OP_LOCALS @@ -9103,10 +9103,6 @@ static void lm_ggml_compute_forward_add_f16_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9182,10 +9178,6 @@ static void lm_ggml_compute_forward_add_bf16_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9261,10 +9253,6 @@ static void lm_ggml_compute_forward_add_f16_f16( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9317,10 +9305,6 @@ static void lm_ggml_compute_forward_add_bf16_bf16( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9373,10 +9357,6 @@ static void lm_ggml_compute_forward_add_q_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int nr = lm_ggml_nrows(src0); LM_GGML_TENSOR_BINARY_OP_LOCALS @@ -9504,6 +9484,9 @@ static void lm_ggml_compute_forward_add( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_add_q_f32(params, dst); } break; @@ -9526,10 +9509,6 @@ static void lm_ggml_compute_forward_add1_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9580,10 +9559,6 @@ static void lm_ggml_compute_forward_add1_f16_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = *(float *) src1->data; @@ -9632,10 +9607,6 @@ static void lm_ggml_compute_forward_add1_f16_f16( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = LM_GGML_FP16_TO_FP32(*(lm_ggml_fp16_t *) src1->data); @@ -9684,10 +9655,6 @@ static void lm_ggml_compute_forward_add1_q_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = *(float *) src1->data; @@ -9753,10 +9720,6 @@ static void lm_ggml_compute_forward_add1_bf16_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = *(float *) src1->data; @@ -9805,10 +9768,6 @@ static void lm_ggml_compute_forward_add1_bf16_bf16( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = LM_GGML_BF16_TO_FP32(*(lm_ggml_bf16_t *) src1->data); @@ -9903,6 +9862,9 @@ static void lm_ggml_compute_forward_add1( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_add1_q_f32(params, dst); } break; @@ -9933,20 +9895,16 @@ static void lm_ggml_compute_forward_acc_f32( size_t offset = ((int32_t *) dst->op_params)[3]; bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - if (!inplace && (params->type == LM_GGML_TASK_TYPE_INIT)) { - if (params->ith != 0) { - return; + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); } - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - lm_ggml_nbytes(dst)); - } - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; + lm_ggml_barrier(params->shared); } const int ith = params->ith; @@ -10032,6 +9990,9 @@ static void lm_ggml_compute_forward_acc( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: default: { LM_GGML_ASSERT(false); @@ -10048,13 +10009,12 @@ static void lm_ggml_compute_forward_sub_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + const int nr = lm_ggml_nrows(src0); LM_GGML_TENSOR_BINARY_OP_LOCALS @@ -10132,23 +10092,9 @@ static void lm_ggml_compute_forward_mul_f32( LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } const int ith = params->ith; const int nth = params->nth; -#if defined(LM_GGML_USE_CLBLAST) - if (src1->backend == LM_GGML_BACKEND_TYPE_GPU) { - // TODO: OpenCL kernel support full broadcast - LM_GGML_ASSERT(lm_ggml_can_repeat_rows(src1, src0)); - if (ith == 0) { - lm_ggml_cl_mul(src0, src1, dst); - } - return; - } -#endif - const int64_t nr = lm_ggml_nrows(src0); LM_GGML_TENSOR_BINARY_OP_LOCALS @@ -10240,10 +10186,6 @@ static void lm_ggml_compute_forward_div_f32( LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -10332,13 +10274,12 @@ static void lm_ggml_compute_forward_sqr_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; @@ -10378,13 +10319,12 @@ static void lm_ggml_compute_forward_sqrt_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; @@ -10424,13 +10364,12 @@ static void lm_ggml_compute_forward_log_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(params->ith == 0); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; @@ -10470,13 +10409,13 @@ static void lm_ggml_compute_forward_sum_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_is_scalar(dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_scalar(dst)); + + assert(lm_ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); @@ -10505,13 +10444,12 @@ static void lm_ggml_compute_forward_sum_f16( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_is_scalar(dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(lm_ggml_fp16_t)); LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) @@ -10539,13 +10477,12 @@ static void lm_ggml_compute_forward_sum_bf16( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_is_scalar(dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(lm_ggml_bf16_t)); LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) @@ -10601,9 +10538,7 @@ static void lm_ggml_compute_forward_sum_rows_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -10656,9 +10591,7 @@ static void lm_ggml_compute_forward_mean_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -10715,9 +10648,7 @@ static void lm_ggml_compute_forward_argmax_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -10765,13 +10696,12 @@ static void lm_ggml_compute_forward_repeat_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(params->ith == 0); - LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); + LM_GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in lm_ggml_can_repeat @@ -10810,13 +10740,12 @@ static void lm_ggml_compute_forward_repeat_f16( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(params->ith == 0); - LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); + LM_GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in lm_ggml_can_repeat @@ -10885,13 +10814,12 @@ static void lm_ggml_compute_forward_repeat_back_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(params->ith == 0); - LM_GGML_ASSERT(lm_ggml_can_repeat(dst, src0)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + LM_GGML_ASSERT(lm_ggml_can_repeat(dst, src0)); + LM_GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in lm_ggml_can_repeat @@ -10965,10 +10893,6 @@ static void lm_ggml_compute_forward_concat_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -11012,7 +10936,7 @@ static void lm_ggml_compute_forward_concat_f32( static void lm_ggml_compute_forward_concat( const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor* dst) { + struct lm_ggml_tensor * dst) { const struct lm_ggml_tensor * src0 = dst->src[0]; @@ -11037,19 +10961,17 @@ static void lm_ggml_compute_forward_abs_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_abs_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11083,19 +11005,17 @@ static void lm_ggml_compute_forward_sgn_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_sgn_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11129,19 +11049,17 @@ static void lm_ggml_compute_forward_neg_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_neg_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11175,19 +11093,17 @@ static void lm_ggml_compute_forward_step_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_step_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11221,19 +11137,17 @@ static void lm_ggml_compute_forward_tanh_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_tanh_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11267,19 +11181,17 @@ static void lm_ggml_compute_forward_elu_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_elu_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11313,19 +11225,17 @@ static void lm_ggml_compute_forward_relu_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_relu_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11359,19 +11269,17 @@ static void lm_ggml_compute_forward_sigmoid_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_sigmoid_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11405,13 +11313,9 @@ static void lm_ggml_compute_forward_gelu_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -11468,13 +11372,9 @@ static void lm_ggml_compute_forward_gelu_quick_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -11531,13 +11431,9 @@ static void lm_ggml_compute_forward_silu_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -11593,13 +11489,14 @@ static void lm_ggml_compute_forward_leaky_relu_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; @@ -11643,15 +11540,11 @@ static void lm_ggml_compute_forward_silu_back_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * grad = dst->src[1]; - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(grad)); - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous_except_dim_1(dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, grad)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(lm_ggml_is_contiguous_1(grad)); + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + assert(lm_ggml_are_same_shape(src0, grad)); const int ith = params->ith; const int nth = params->nth; @@ -11708,19 +11601,17 @@ static void lm_ggml_compute_forward_hardswish_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_hardswish_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11751,19 +11642,17 @@ static void lm_ggml_compute_forward_hardsigmoid_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { lm_ggml_vec_hardsigmoid_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11800,10 +11689,6 @@ static void lm_ggml_compute_forward_norm_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -11875,10 +11760,6 @@ static void lm_ggml_compute_forward_rms_norm_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -11946,10 +11827,6 @@ static void lm_ggml_compute_forward_rms_norm_back_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst) && lm_ggml_are_same_shape(src0, src1)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -12124,10 +12001,6 @@ static void lm_ggml_compute_forward_group_norm_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -12214,39 +12087,6 @@ static void lm_ggml_compute_forward_group_norm( // lm_ggml_compute_forward_mul_mat -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool lm_ggml_compute_forward_mul_mat_use_blas(struct lm_ggml_tensor * dst) { - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - //const int64_t ne00 = src0->ne[0]; - //const int64_t ne01 = src0->ne[1]; - - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - // NOTE: with LM_GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float) - // all the experts for each batch element and the processing would become incredibly slow - // TODO: find the optimal values for these - if (dst->op != LM_GGML_OP_MUL_MAT_ID && - lm_ggml_is_contiguous(src0) && - lm_ggml_is_contiguous(src1) && - //src0->type == LM_GGML_TYPE_F32 && - src1->type == LM_GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - - /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ - return true; - } - - return false; -} -#endif - static void lm_ggml_compute_forward_mul_mat_one_chunk( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst, @@ -12265,8 +12105,8 @@ static void lm_ggml_compute_forward_mul_mat_one_chunk( const bool src1_cont = lm_ggml_is_contiguous(src1); - lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; // broadcast factors const int64_t r2 = ne12 / ne02; @@ -12340,15 +12180,11 @@ static void lm_ggml_compute_forward_mul_mat_one_chunk( static void lm_ggml_compute_forward_mul_mat( const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - struct lm_ggml_compute_state * state) { + struct lm_ggml_tensor * dst) { const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -12356,9 +12192,14 @@ static void lm_ggml_compute_forward_mul_mat( const enum lm_ggml_type type = src0->type; - enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - lm_ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - int64_t const vec_dot_num_rows = type_traits[type].nrows; + enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + lm_ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + lm_ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; + int64_t const vec_dot_num_rows = type_traits[type].nrows; + int64_t const matmul_num_cols = type_traits[type].ncols; + int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; + lm_ggml_gemv_t const gemv = type_traits[type].gemv; + lm_ggml_gemm_t const gemm = type_traits[type].gemm; LM_GGML_ASSERT(ne0 == ne01); LM_GGML_ASSERT(ne1 == ne11); @@ -12375,92 +12216,14 @@ static void lm_ggml_compute_forward_mul_mat( LM_GGML_ASSERT(nb1 <= nb2); LM_GGML_ASSERT(nb2 <= nb3); - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - UNUSED(r2); - UNUSED(r3); - // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(LM_GGML_USE_CLBLAST) - if (lm_ggml_cl_can_mul_mat(src0, src1, dst)) { - if (params->ith == 0 && params->type == LM_GGML_TASK_TYPE_COMPUTE) { - lm_ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); - } - return; - } -#endif - -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) - if (lm_ggml_compute_forward_mul_mat_use_blas(dst)) { - const int64_t ne_plane = ne01*ne00; - const size_t desired_wsize = ne13*ne12*ne_plane*sizeof(float); - UNUSED(desired_wsize); - - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (type != LM_GGML_TYPE_F32) { - assert(params->wsize >= desired_wsize); - // parallelize by src0 rows - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - // broadcast src0 into src1 across 2nd,3rd dimension - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane; - lm_ggml_to_float_t const to_float = type_traits[type].to_float; - - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00); - } - } - } - } - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - - // perform sgemm, parallelization controlled by blas lib - if (ith != 0) { - return; - } - - //const int64_t tgemm0 = lm_ggml_perf_time_us(); - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - - if (type != LM_GGML_TYPE_F32) { - x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane; - } - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne1, ne01, ne10, - 1.0f, y, ne10, - x, ne00, - 0.0f, d, ne01); - } - } - //printf("cblas_sgemm = %.3f ms, %lld flops\n", (lm_ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2); - - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (lm_ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - #if LM_GGML_USE_LLAMAFILE + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + const bool src1_cont = lm_ggml_is_contiguous(src1); if (src1_cont) { @@ -12474,7 +12237,6 @@ static void lm_ggml_compute_forward_mul_mat( (char *)dst->data + i12*nb2 + i13*nb3, nb1/lm_ggml_type_size(dst->type), ith, nth, - params->type, src0->type, src1->type, dst->type)) @@ -12484,36 +12246,43 @@ static void lm_ggml_compute_forward_mul_mat( UseGgmlGemm1:; #endif - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - atomic_store(&state->shared->current_chunk, nth); - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - - assert(params->wsize >= ne11*ne12*ne13*row_size); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + + const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + int64_t i11_processed = 0; + if ((lm_ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + 4, ne10, blck_size_interleave); } + i11_processed = ne11 - ne11 % 4; + } + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } - - return; } - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + atomic_store(¶ms->shared->current_chunk, nth); } + lm_ggml_barrier(params->shared); + #if LM_GGML_USE_LLAMAFILE if (src1->type != vec_dot_type) { const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; @@ -12529,7 +12298,6 @@ UseGgmlGemm1:; (char *)dst->data + i12*nb2 + i13*nb3, nb1/lm_ggml_type_size(dst->type), ith, nth, - params->type, src0->type, vec_dot_type, dst->type)) @@ -12539,11 +12307,6 @@ UseGgmlGemm1:; UseGgmlGemm2:; #endif -#ifdef LM_GGML_PERF - int chunks_executed = 0; - UNUSED(chunks_executed); -#endif - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) const int64_t nr0 = ne0; @@ -12585,8 +12348,27 @@ UseGgmlGemm2:; const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - //if (ith == 0) - // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); + if ((lm_ggml_n_dims(src0) == 2) && gemv) { + const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t src1_col_stride = lm_ggml_is_contiguous(src1) || src1->type != vec_dot_type ? lm_ggml_row_size(vec_dot_type, ne10) : nb11; + int64_t src0_start = (ith * ne01) / nth; + int64_t src0_end = ((ith + 1) * ne01) / nth; + src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; + src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; + if (src0_start >= src0_end) return; + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (gemm && (ne11 > 3)) { + gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { + gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + return; + } // The first chunk comes from our thread_id, the rest will get auto-assigned. int current_chunk = ith; @@ -12603,23 +12385,12 @@ UseGgmlGemm2:; lm_ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); -#ifdef LM_GGML_PERF - chunks_executed++; -#endif - if (nth >= nchunk0 * nchunk1) { break; } - current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); + current_chunk = atomic_fetch_add(¶ms->shared->current_chunk, 1); } - -#ifdef LM_GGML_PERF - // These numbers are useful when trying to measure how well the threading scheduling works. - //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1; - //float time = (lm_ggml_perf_time_us() - t0); - //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed); -#endif } // lm_ggml_compute_forward_mul_mat_id @@ -12641,9 +12412,11 @@ static void lm_ggml_compute_forward_mul_mat_id( const bool src1_cont = lm_ggml_is_contiguous(src1); - lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - lm_ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + lm_ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + int64_t const matmul_num_cols = type_traits[type].ncols; + lm_ggml_gemv_t const gemv = type_traits[type].gemv; // we don't support permuted src0 or src1 LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); @@ -12671,32 +12444,33 @@ static void lm_ggml_compute_forward_mul_mat_id( int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (src1->type != vec_dot_type) { char * wdata = params->wdata; - if (src1->type != vec_dot_type) { - const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - assert(params->wsize >= ne11*ne12*ne13*row_size); - assert(src1->type == LM_GGML_TYPE_F32); + const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; - } + assert(params->wsize >= ne13*nbw3); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + if (ith == 0) { // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - // group rows by src0 matrix for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { for (int id = 0; id < n_ids; ++id) { @@ -12708,13 +12482,9 @@ static void lm_ggml_compute_forward_mul_mat_id( matrix_row_counts[i02] += 1; } } - - return; } - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + lm_ggml_barrier(params->shared); // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { @@ -12732,6 +12502,34 @@ static void lm_ggml_compute_forward_mul_mat_id( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows + if (((lm_ggml_n_dims(src0) - 1) == 2) && gemv) { + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; + src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; + if (src0_cur_start >= src0_cur_end) return; + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12)); + + gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); + } + continue; + } + // distribute the thread work across the inner or outer loop based on which one is larger const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows @@ -12813,9 +12611,6 @@ static void lm_ggml_compute_forward_out_prod_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - // int64_t t0 = lm_ggml_perf_time_us(); - // UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -12840,75 +12635,10 @@ static void lm_ggml_compute_forward_out_prod_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // TODO: #if defined(LM_GGML_USE_CLBLAST) - -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) - bool use_blas = lm_ggml_is_matrix(src0) && - lm_ggml_is_matrix(src1) && - lm_ggml_is_contiguous(src0) && - (lm_ggml_is_contiguous(src1) || lm_ggml_is_transposed(src1)); -#endif - - if (params->type == LM_GGML_TASK_TYPE_INIT) { -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) // gemm beta will zero dst - if (use_blas) { - return; - } -#endif - if (ith != 0) { - return; - } + if (ith == 0) { lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) - if (use_blas) { - if (params->ith != 0) { // All threads other than the first do no work. - return; - } - // Arguments to lm_ggml_compute_forward_out_prod (expressed as major,minor) - // src0: (k,n) - // src1: (k,m) - // dst: (m,n) - // - // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) - // Also expressed as (major,minor) - // a: (m,k): so src1 transposed - // b: (k,n): so src0 - // c: (m,n) - // - // However, if lm_ggml_is_transposed(src1) is true, then - // src1->data already contains a transposed version, so sgemm mustn't - // transpose it further. - - int n = src0->ne[0]; - int k = src0->ne[1]; - int m = src1->ne[0]; - - int transposeA, lda; - - if (!lm_ggml_is_transposed(src1)) { - transposeA = CblasTrans; - lda = m; - } else { - transposeA = CblasNoTrans; - lda = k; - } - - float * a = (float *) ((char *) src1->data); - float * b = (float *) ((char *) src0->data); - float * c = (float *) ((char *) dst->data); - - cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); - - return; } -#endif + lm_ggml_barrier(params->shared); // dst[:,:,:,:] = 0 // for i2,i3: @@ -12984,19 +12714,6 @@ static void lm_ggml_compute_forward_out_prod_f32( } } } - - //int64_t t1 = lm_ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} } static void lm_ggml_compute_forward_out_prod_q_f32( @@ -13006,9 +12723,6 @@ static void lm_ggml_compute_forward_out_prod_q_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - // int64_t t0 = lm_ggml_perf_time_us(); - // UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; @@ -13039,19 +12753,10 @@ static void lm_ggml_compute_forward_out_prod_q_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // TODO: #if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) || defined(LM_GGML_USE_CLBLAST) - - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; } + lm_ggml_barrier(params->shared); // parallelize by last three dimensions @@ -13098,19 +12803,6 @@ static void lm_ggml_compute_forward_out_prod_q_f32( lm_ggml_vec_mad_f32(ne0, d, wdata, *s1); } } - - //int64_t t1 = lm_ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} } static void lm_ggml_compute_forward_out_prod( @@ -13139,6 +12831,9 @@ static void lm_ggml_compute_forward_out_prod( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_out_prod_q_f32(params, dst); } break; @@ -13170,10 +12865,6 @@ static void lm_ggml_compute_forward_scale_f32( LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // scale factor float v; memcpy(&v, dst->op_params, sizeof(float)); @@ -13242,20 +12933,16 @@ static void lm_ggml_compute_forward_set_f32( size_t offset = ((int32_t *) dst->op_params)[3]; bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - if (!inplace && (params->type == LM_GGML_TASK_TYPE_INIT)) { - if (params->ith != 0) { - return; + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); } - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - lm_ggml_nbytes(dst)); - } - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; + lm_ggml_barrier(params->shared); } const int ith = params->ith; @@ -13332,6 +13019,9 @@ static void lm_ggml_compute_forward_set( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: default: { LM_GGML_ASSERT(false); @@ -13404,10 +13094,6 @@ static void lm_ggml_compute_forward_get_rows_q( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13437,6 +13123,8 @@ static void lm_ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + assert(i01 >= 0 && i01 < ne01); + dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); @@ -13450,10 +13138,6 @@ static void lm_ggml_compute_forward_get_rows_f16( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13480,6 +13164,8 @@ static void lm_ggml_compute_forward_get_rows_f16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + assert(i01 >= 0 && i01 < ne01); + lm_ggml_fp16_to_fp32_row( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); @@ -13493,10 +13179,6 @@ static void lm_ggml_compute_forward_get_rows_bf16( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13523,7 +13205,9 @@ static void lm_ggml_compute_forward_get_rows_bf16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - lm_ggml_bf16_to_fp32_row( + assert(i01 >= 0 && i01 < ne01); + + lm_ggml_bf16_to_fp32_row( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } @@ -13536,10 +13220,6 @@ static void lm_ggml_compute_forward_get_rows_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13566,6 +13246,8 @@ static void lm_ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + assert(i01 >= 0 && i01 < ne01); + lm_ggml_vec_cpy_f32(nc, (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); @@ -13599,6 +13281,9 @@ static void lm_ggml_compute_forward_get_rows( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_get_rows_q(params, dst); } break; @@ -13649,21 +13334,15 @@ static void lm_ggml_compute_forward_get_rows_back_f32_f16( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - LM_GGML_ASSERT(params->ith == 0); + if (params->ith != 0) { + return; + } + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); // lm_ggml_compute_forward_dup_same_cont(params, opt0, dst); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (params->ith != 0) { - return; - } - memset(dst->data, 0, lm_ggml_nbytes(dst)); - } - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + memset(dst->data, 0, lm_ggml_nbytes(dst)); const int nc = src0->ne[0]; const int nr = lm_ggml_nelements(src1); @@ -13688,21 +13367,15 @@ static void lm_ggml_compute_forward_get_rows_back_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - LM_GGML_ASSERT(params->ith == 0); + if (params->ith != 0) { + return; + } + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); // lm_ggml_compute_forward_dup_same_cont(params, opt0, dst); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (params->ith != 0) { - return; - } - memset(dst->data, 0, lm_ggml_nbytes(dst)); - } - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } + memset(dst->data, 0, lm_ggml_nbytes(dst)); const int nc = src0->ne[0]; const int nr = lm_ggml_nelements(src1); @@ -13768,9 +13441,7 @@ static void lm_ggml_compute_forward_diag_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -13839,22 +13510,18 @@ static void lm_ggml_compute_forward_diag_mask_f32( LM_GGML_ASSERT(n_past >= 0); - if (!inplace && (params->type == LM_GGML_TASK_TYPE_INIT)) { - if (ith != 0) { - return; + if (!inplace) { + if (ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); } - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); - memcpy( - ((char *) dst->data), - ((char *) src0->data), - lm_ggml_nbytes(dst)); - } - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; + lm_ggml_barrier(params->shared); } // TODO: handle transposed/permuted matrices @@ -13926,10 +13593,6 @@ static void lm_ggml_compute_forward_soft_max_f32( assert(lm_ggml_is_contiguous(dst)); assert(lm_ggml_are_same_shape(src0, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - float scale = 1.0f; float max_bias = 0.0f; @@ -14036,6 +13699,7 @@ static void lm_ggml_compute_forward_soft_max( } } + // lm_ggml_compute_forward_soft_max_back static void lm_ggml_compute_forward_soft_max_back_f32( @@ -14051,10 +13715,6 @@ static void lm_ggml_compute_forward_soft_max_back_f32( LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); LM_GGML_ASSERT(lm_ggml_are_same_shape(src1, dst)); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // TODO: handle transposed/permuted matrices const int ith = params->ith; @@ -14143,9 +13803,7 @@ static void lm_ggml_compute_forward_clamp_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -14213,6 +13871,9 @@ static void lm_ggml_compute_forward_clamp( case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: case LM_GGML_TYPE_Q8_K: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: case LM_GGML_TYPE_I8: case LM_GGML_TYPE_I16: case LM_GGML_TYPE_I32: @@ -14236,8 +13897,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -14254,18 +13914,19 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float lm_ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +static float lm_ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } static void lm_ggml_rope_cache_init( - float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, - float * cache, float sin_sign, float theta_scale -) { + float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; rope_yarn( - theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] ); cache[i0 + 1] *= sin_sign; @@ -14274,11 +13935,11 @@ static void lm_ggml_rope_cache_init( } LM_GGML_CALL void lm_ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - float start = floorf(lm_ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)); - float end = ceilf(lm_ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)); + float start = floorf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); dims[0] = MAX(0, start); dims[1] = MIN(n_dims - 1, end); } @@ -14292,21 +13953,13 @@ static void lm_ggml_compute_forward_rope_f32( const struct lm_ggml_tensor * src1 = dst->src[1]; const struct lm_ggml_tensor * src2 = dst->src[2]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - // these two only relevant for xPos RoPE: - float xpos_base; - bool xpos_down; - //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -14314,8 +13967,6 @@ static void lm_ggml_compute_forward_rope_f32( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); LM_GGML_TENSOR_UNARY_OP_LOCALS @@ -14343,22 +13994,17 @@ static void lm_ggml_compute_forward_rope_f32( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; - lm_ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - LM_GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14373,101 +14019,50 @@ static void lm_ggml_compute_forward_rope_f32( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - lm_ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - const float x2 = src[n_dims]; - const float x3 = src[n_dims/2*3]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; - dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; - if (xpos_down) zeta = 1.0f / zeta; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[1]; - dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; - dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // it seems we have to rope just the first n_dims elements and do nothing with the rest - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - theta_base *= freq_scale; - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t ib = 0; - - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; - float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); - sin_theta *= sin_sign; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - theta_base *= theta_scale; - - const int64_t i0 = ib*n_dims + ic/2; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -14484,17 +14079,13 @@ static void lm_ggml_compute_forward_rope_f16( const struct lm_ggml_tensor * src1 = dst->src[1]; const struct lm_ggml_tensor * src2 = dst->src[2]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -14528,22 +14119,17 @@ static void lm_ggml_compute_forward_rope_f16( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; - lm_ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - LM_GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14558,43 +14144,14 @@ static void lm_ggml_compute_forward_rope_f16( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - lm_ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = LM_GGML_FP16_TO_FP32(src[0]); - const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); - const float x2 = LM_GGML_FP16_TO_FP32(src[n_dims]); - const float x3 = LM_GGML_FP16_TO_FP32(src[n_dims/2*3]); - - dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); - dst_data[n_dims/2*3] = LM_GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; @@ -14608,47 +14165,29 @@ static void lm_ggml_compute_forward_rope_f16( dst_data[1] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // it seems we have to rope just the first n_dims elements and do nothing with the rest - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - theta_base *= freq_scale; - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t ib = 0; - - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; - float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - float cos_theta, sin_theta; - rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); - sin_theta *= sin_sign; - - theta_base *= theta_scale; - - const int64_t i0 = ib*n_dims + ic/2; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - const float x0 = LM_GGML_FP16_TO_FP32(src[0]); - const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); + const float x0 = LM_GGML_FP16_TO_FP32(src[0]); + const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); - dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } else { - const int64_t i0 = ic; + dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -14714,9 +14253,6 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32( LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -14727,10 +14263,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32( LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); LM_GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { memset(params->wdata, 0, params->wsize); // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) @@ -14763,13 +14296,8 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, lm_ggml_nbytes(dst)); - - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; } + lm_ggml_barrier(params->shared); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14813,9 +14341,6 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f32( LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -14826,10 +14351,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f32( LM_GGML_ASSERT(nb00 == sizeof(float)); LM_GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { memset(params->wdata, 0, params->wsize); // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) @@ -14862,13 +14384,8 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, lm_ggml_nbytes(dst)); - - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; } + lm_ggml_barrier(params->shared); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14937,9 +14454,6 @@ static void lm_ggml_compute_forward_im2col_f32( LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS; const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; @@ -14970,14 +14484,6 @@ static void lm_ggml_compute_forward_im2col_f32( LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); LM_GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { float * const wdata = (float *) dst->data; @@ -15025,9 +14531,6 @@ static void lm_ggml_compute_forward_im2col_f16( LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16); - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS; const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; @@ -15058,14 +14561,6 @@ static void lm_ggml_compute_forward_im2col_f16( LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); LM_GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) dst->data; @@ -15131,9 +14626,6 @@ static void lm_ggml_compute_forward_conv_transpose_2d( LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -15144,10 +14636,7 @@ static void lm_ggml_compute_forward_conv_transpose_2d( LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); LM_GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { memset(params->wdata, 0, params->wsize); // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) @@ -15182,13 +14671,8 @@ static void lm_ggml_compute_forward_conv_transpose_2d( } memset(dst->data, 0, lm_ggml_nbytes(dst)); - - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; } + lm_ggml_barrier(params->shared); const int32_t stride = lm_ggml_get_op_params_i32(dst, 0); @@ -15236,9 +14720,8 @@ static void lm_ggml_compute_forward_pool_1d_sk_p0( const struct lm_ggml_tensor * src = dst->src[0]; assert(src->type == LM_GGML_TYPE_F32); - assert(params->ith == 0); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -15305,9 +14788,8 @@ static void lm_ggml_compute_forward_pool_2d( const struct lm_ggml_tensor * src = dst->src[0]; LM_GGML_ASSERT(src->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(params->ith == 0); - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -15380,10 +14862,6 @@ static void lm_ggml_compute_forward_upscale_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); const int ith = params->ith; @@ -15444,10 +14922,6 @@ static void lm_ggml_compute_forward_pad_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); @@ -15504,10 +14978,6 @@ static void lm_ggml_compute_forward_arange_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_ASSERT(dst->nb[0] == sizeof(float)); const int ith = params->ith; @@ -15546,10 +15016,6 @@ static void lm_ggml_compute_forward_timestep_embedding_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const struct lm_ggml_tensor * src0 = dst->src[0]; LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15605,10 +15071,6 @@ static void lm_ggml_compute_forward_argsort_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_UNARY_OP_LOCALS LM_GGML_ASSERT(nb0 == sizeof(float)); @@ -15669,8 +15131,6 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16( const struct lm_ggml_tensor * v, const struct lm_ggml_tensor * mask, struct lm_ggml_tensor * dst) { - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) @@ -15715,14 +15175,6 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16( const int64_t rv2 = neq2/nev2; const int64_t rv3 = neq3/nev3; - if (params->type == LM_GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // parallelize by q rows using lm_ggml_vec_dot_f32 // total rows in q @@ -15905,9 +15357,6 @@ static void lm_ggml_compute_forward_flash_attn_back_f32( const struct lm_ggml_tensor * v = dst->src[2]; const struct lm_ggml_tensor * d = dst->src[3]; - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) LM_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) @@ -15954,16 +15403,10 @@ static void lm_ggml_compute_forward_flash_attn_back_f32( LM_GGML_ASSERT(nb1 <= nb2); LM_GGML_ASSERT(nb2 <= nb3); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith == 0) { - memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); - } - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); } + lm_ggml_barrier(params->shared); const int64_t elem_q = lm_ggml_nelements(q); const int64_t elem_k = lm_ggml_nelements(k); @@ -16243,10 +15686,6 @@ static void lm_ggml_compute_forward_flash_attn_back( static void lm_ggml_compute_forward_ssm_conv_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const struct lm_ggml_tensor * src0 = dst->src[0]; // conv_state const struct lm_ggml_tensor * src1 = dst->src[1]; // x const struct lm_ggml_tensor * src2 = dst->src[2]; // conv1d.weight @@ -16369,10 +15808,6 @@ static void lm_ggml_compute_forward_ssm_conv( static void lm_ggml_compute_forward_ssm_scan_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const struct lm_ggml_tensor * src0 = dst->src[0]; // s const struct lm_ggml_tensor * src1 = dst->src[1]; // x const struct lm_ggml_tensor * src2 = dst->src[2]; // dt @@ -16494,13 +15929,10 @@ static void lm_ggml_compute_forward_ssm_scan( static void lm_ggml_compute_forward_win_part_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { + UNUSED(params); const struct lm_ggml_tensor * src0 = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) @@ -16560,13 +15992,10 @@ static void lm_ggml_compute_forward_win_part( static void lm_ggml_compute_forward_win_unpart_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { + UNUSED(params); const struct lm_ggml_tensor * src0 = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) @@ -16692,13 +16121,10 @@ static void lm_ggml_compute_forward_unary( static void lm_ggml_compute_forward_get_rel_pos_f16( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { + UNUSED(params); const struct lm_ggml_tensor * src0 = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 LM_GGML_TENSOR_UNARY_OP_LOCALS @@ -16748,20 +16174,12 @@ static void lm_ggml_compute_forward_add_rel_pos_f32( const struct lm_ggml_tensor * src2 = dst->src[2]; const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; - if (!inplace && params->type == LM_GGML_TASK_TYPE_INIT) { - if (params->ith != 0) { - return; + if (!inplace) { + if (params->ith == 0) { + memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst)); } - memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst)); - return; - } - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; + lm_ggml_barrier(params->shared); } - - int64_t t0 = lm_ggml_perf_time_us(); - UNUSED(t0); - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 float * src1_data = (float *) src1->data; @@ -16835,18 +16253,17 @@ static void lm_ggml_compute_forward_map_unary_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { fun(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -16883,20 +16300,18 @@ static void lm_ggml_compute_forward_map_binary_f32( const struct lm_ggml_tensor * src0 = dst->src[0]; const struct lm_ggml_tensor * src1 = dst->src[1]; - assert(params->ith == 0); - assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(src1)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { fun(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -16933,9 +16348,7 @@ static void lm_ggml_compute_forward_map_custom1_f32( const struct lm_ggml_tensor * a = dst->src[0]; - assert(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -16952,9 +16365,7 @@ static void lm_ggml_compute_forward_map_custom2_f32( const struct lm_ggml_tensor * a = dst->src[0]; const struct lm_ggml_tensor * b = dst->src[1]; - assert(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -16972,9 +16383,7 @@ static void lm_ggml_compute_forward_map_custom3_f32( const struct lm_ggml_tensor * b = dst->src[1]; const struct lm_ggml_tensor * c = dst->src[1]; - assert(params->ith == 0); - - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -16989,10 +16398,6 @@ static void lm_ggml_compute_forward_map_custom1( const struct lm_ggml_tensor * a = dst->src[0]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - struct lm_ggml_map_custom1_op_params p; memcpy(&p, dst->op_params, sizeof(p)); @@ -17008,10 +16413,6 @@ static void lm_ggml_compute_forward_map_custom2( const struct lm_ggml_tensor * a = dst->src[0]; const struct lm_ggml_tensor * b = dst->src[1]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - struct lm_ggml_map_custom2_op_params p; memcpy(&p, dst->op_params, sizeof(p)); @@ -17028,10 +16429,6 @@ static void lm_ggml_compute_forward_map_custom3( const struct lm_ggml_tensor * b = dst->src[1]; const struct lm_ggml_tensor * c = dst->src[2]; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - struct lm_ggml_map_custom3_op_params p; memcpy(&p, dst->op_params, sizeof(p)); @@ -17063,21 +16460,10 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32( LM_GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - if (params->type == LM_GGML_TASK_TYPE_INIT) { - if (ith == 0) { - memset(sums, 0, sizeof(float) * (nth + nth * nc)); - } - return; - } - - if (params->type == LM_GGML_TASK_TYPE_FINALIZE) { - if (ith == 0) { - float * dp = (float *) dst->data; - lm_ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f / (float) nr; - } - return; + if (ith == 0) { + memset(sums, 0, sizeof(float) * (nth + nth * nc)); } + lm_ggml_barrier(params->shared); const double eps = 1e-9; @@ -17125,7 +16511,13 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32( } #endif } + lm_ggml_barrier(params->shared); + if (ith == 0) { + float * dp = (float *) dst->data; + lm_ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f / (float) nr; + } } static void lm_ggml_compute_forward_cross_entropy_loss( @@ -17165,10 +16557,6 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - if (params->type == LM_GGML_TASK_TYPE_INIT || params->type == LM_GGML_TASK_TYPE_FINALIZE) { - return; - } - const double eps = 1e-9; // TODO: handle transposed/permuted matrices @@ -17239,7 +16627,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back( ///////////////////////////////// -static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor, struct lm_ggml_compute_state * state) { +static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) { LM_GGML_ASSERT(params); if (tensor->op == LM_GGML_OP_NONE || lm_ggml_is_empty(tensor)) { @@ -17337,7 +16725,7 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru } break; case LM_GGML_OP_MUL_MAT: { - lm_ggml_compute_forward_mul_mat(params, tensor, state); + lm_ggml_compute_forward_mul_mat(params, tensor); } break; case LM_GGML_OP_MUL_MAT_ID: { @@ -18350,9 +17738,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18360,8 +17748,6 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = lm_ggml_add_or_set(ctx, src0->grad, @@ -18371,16 +17757,13 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, - beta_slow, - xpos_base, - xpos_down), + beta_slow), zero_table); } } break; @@ -18390,9 +17773,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18400,8 +17783,6 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = lm_ggml_add_or_set(ctx, src0->grad, @@ -18411,16 +17792,13 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, - xpos_base, - xpos_down, false), zero_table); } @@ -18825,9 +18203,6 @@ struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, s /*.leafs =*/ leafs_ptr, /*.hash_table =*/ { hash_size, hash_keys_ptr }, /*.order =*/ LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, }; return cgraph; @@ -18847,9 +18222,6 @@ struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph0, int i0 /*.leafs =*/ NULL, /*.hash_table =*/ { 0, NULL }, /*.order =*/ cgraph0->order, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, }; return cgraph; @@ -19043,16 +18415,7 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } static void clear_numa_thread_affinity(void) {} #endif -static void lm_ggml_graph_compute_perf_stats_node(struct lm_ggml_tensor * node, const struct lm_ggml_compute_state_shared * st) { - int64_t cycles_cur = lm_ggml_perf_cycles() - st->perf_node_start_cycles; - int64_t time_us_cur = lm_ggml_perf_time_us() - st->perf_node_start_time_us; - - node->perf_runs++; - node->perf_cycles += cycles_cur; - node->perf_time_us += time_us_cur; -} - -static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int n_cur_threads) { +static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { int n_tasks = 0; if (lm_ggml_is_empty(node)) { @@ -19064,6 +18427,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int switch (node->op) { case LM_GGML_OP_CPY: case LM_GGML_OP_DUP: + case LM_GGML_OP_CONT: case LM_GGML_OP_ADD: case LM_GGML_OP_ADD1: case LM_GGML_OP_ACC: @@ -19094,8 +18458,8 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int case LM_GGML_UNARY_OP_ELU: case LM_GGML_UNARY_OP_RELU: case LM_GGML_UNARY_OP_SIGMOID: - case LM_GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads - case LM_GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads + case LM_GGML_UNARY_OP_HARDSWISH: + case LM_GGML_UNARY_OP_HARDSIGMOID: { n_tasks = 1; } break; @@ -19118,37 +18482,21 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int case LM_GGML_OP_RMS_NORM_BACK: case LM_GGML_OP_GROUP_NORM: case LM_GGML_OP_CONCAT: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_MUL_MAT: - { - n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = lm_ggml_nrows(node->src[0]); - //const int nr1 = lm_ggml_nrows(node->src[1]); - - //n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - } break; case LM_GGML_OP_MUL_MAT_ID: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_OUT_PROD: { n_tasks = n_threads; } break; case LM_GGML_OP_GET_ROWS: { - // FIXME: the cost of launching additional threads decreases performance with GPU offloading - //n_tasks = MIN(n_threads, lm_ggml_nelements(node->src[1])); - n_tasks = MIN(n_cur_threads, lm_ggml_nelements(node->src[1])); + // FIXME: get_rows can use additional threads, but the cost of launching additional threads + // decreases performance with GPU offloading + //n_tasks = n_threads; + n_tasks = 1; } break; case LM_GGML_OP_SCALE: case LM_GGML_OP_SET: - case LM_GGML_OP_CONT: case LM_GGML_OP_RESHAPE: case LM_GGML_OP_VIEW: case LM_GGML_OP_PERMUTE: @@ -19175,14 +18523,8 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int { n_tasks = MIN(n_threads, lm_ggml_nrows(node->src[0])); } break; - case LM_GGML_OP_CONV_TRANSPOSE_1D: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_IM2COL: - { - n_tasks = n_threads; - } break; + case LM_GGML_OP_CONV_TRANSPOSE_1D: case LM_GGML_OP_CONV_TRANSPOSE_2D: { n_tasks = n_threads; @@ -19193,33 +18535,12 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int n_tasks = 1; } break; case LM_GGML_OP_UPSCALE: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_PAD: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_ARANGE: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_TIMESTEP_EMBEDDING: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_ARGSORT: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_FLASH_ATTN_EXT: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_FLASH_ATTN_BACK: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_SSM_CONV: case LM_GGML_OP_SSM_SCAN: { @@ -19267,9 +18588,6 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int } } break; case LM_GGML_OP_CROSS_ENTROPY_LOSS: - { - n_tasks = n_threads; - } break; case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: { n_tasks = n_threads; @@ -19299,184 +18617,6 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads, int return n_tasks; } -static void lm_ggml_graph_compute_thread_sync_node(int * node_n, struct lm_ggml_compute_state * state, const bool do_yield) { - // wait for other threads to finish - const int last_node_n = * node_n; - - while (true) { - if (do_yield) { - sched_yield(); - } - - * node_n = atomic_load(&state->shared->node_n); - if (* node_n != last_node_n) break; -#if defined(__SSE3__) - // Tell the processor we're spinning. It's a processor hint for spinlocks. - _mm_pause(); -#endif - } -} - -static void lm_ggml_graph_compute_thread_sync_task(int * task_phase, struct lm_ggml_compute_state * state, const bool do_yield) { - // wait for other threads to finish - const int last_task_phase = * task_phase; - - while (true) { - if (do_yield) { - sched_yield(); - } - - * task_phase = atomic_load(&state->shared->node_task); - if (* task_phase != last_task_phase) break; -#if defined(__SSE3__) - // Tell the processor we're spinning. It's a processor hint for spinlocks. - _mm_pause(); -#endif - } -} - -static thread_ret_t lm_ggml_graph_compute_thread(void * data) { - struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; - - const struct lm_ggml_cgraph * cgraph = state->shared->cgraph; - const struct lm_ggml_cplan * cplan = state->shared->cplan; - - const int n_threads = state->shared->n_threads; - - set_numa_thread_affinity(state->ith); - - int node_n = -1; - int task_phase = LM_GGML_TASK_TYPE_FINALIZE; - - while (true) { - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->shared->node_n += 1; - state->ec = LM_GGML_STATUS_ABORTED; - return 0; - } - - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - // all other threads are finished and spinning - // do finalize and init here so we don't have synchronize again - struct lm_ggml_compute_params params = { - /*.type =*/ LM_GGML_TASK_TYPE_FINALIZE, - /*.ith =*/ 0, - /*.nth =*/ 0, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - }; - - if (node_n != -1) { - /* FINALIZE */ - struct lm_ggml_tensor * node = cgraph->nodes[node_n]; - if (LM_GGML_OP_HAS_FINALIZE[node->op]) { - params.nth = lm_ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - lm_ggml_compute_forward(¶ms, node, state); - } - lm_ggml_graph_compute_perf_stats_node(node, state->shared); - } - - // distribute new work or execute it direct if 1T - while (++node_n < cgraph->n_nodes) { - LM_GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); - struct lm_ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = lm_ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - - state->shared->perf_node_start_cycles = lm_ggml_perf_cycles(); - state->shared->perf_node_start_time_us = lm_ggml_perf_time_us(); - - params.nth = n_tasks; - - if (n_tasks == 1) { - /* INIT */ - if (LM_GGML_OP_HAS_INIT[node->op]) { - params.type = LM_GGML_TASK_TYPE_INIT; - lm_ggml_compute_forward(¶ms, node, state); - } - - // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, - // they do something more efficient than spinning (?) - params.type = LM_GGML_TASK_TYPE_COMPUTE; - lm_ggml_compute_forward(¶ms, node, state); - - if (LM_GGML_OP_HAS_FINALIZE[node->op]) { - params.type = LM_GGML_TASK_TYPE_FINALIZE; - lm_ggml_compute_forward(¶ms, node, state); - } - - lm_ggml_graph_compute_perf_stats_node(node, state->shared); - } else { - break; - } - - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - break; - } - } - - task_phase = LM_GGML_TASK_TYPE_INIT; - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_n, node_n); - atomic_store(&state->shared->node_task, task_phase); - } else { - lm_ggml_graph_compute_thread_sync_node(&node_n, state, false); - lm_ggml_graph_compute_thread_sync_task(&task_phase, state, false); - } - - // check if we should stop - if (node_n >= cgraph->n_nodes) break; - - /* INIT & COMPUTE */ - struct lm_ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = lm_ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - - struct lm_ggml_compute_params params = { - /*.type =*/ LM_GGML_TASK_TYPE_INIT, - /*.ith =*/ state->ith, - /*.nth =*/ n_tasks, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - }; - - if (state->ith < n_tasks) { - if (LM_GGML_OP_HAS_INIT[node->op]) { - lm_ggml_compute_forward(¶ms, node, state); - } - } - - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - task_phase = LM_GGML_TASK_TYPE_COMPUTE; - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_task, task_phase); - } - else { - // TODO: this sched_yield can have significant impact on the performance - either positive or negative - // depending on the workload and the operating system. - // since it is not clear what is the best approach, it should potentially become user-configurable - // ref: https://github.com/ggerganov/ggml/issues/291 - // UPD: adding the do_yield flag seems to resolve the issue universally - const bool do_yield = node_n < 0 || cgraph->nodes[node_n]->op == LM_GGML_OP_MUL_MAT; - lm_ggml_graph_compute_thread_sync_task(&task_phase, state, do_yield); - } - - if (state->ith < n_tasks) { - params.type = LM_GGML_TASK_TYPE_COMPUTE; - lm_ggml_compute_forward(¶ms, node, state); - } - - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - task_phase = LM_GGML_TASK_TYPE_FINALIZE; - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_task, task_phase); - } - else { - lm_ggml_graph_compute_thread_sync_task(&task_phase, state, false); - } - } - - return 0; -} - struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, int n_threads) { if (n_threads <= 0) { n_threads = LM_GGML_DEFAULT_N_THREADS; @@ -19493,7 +18633,7 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in for (int i = 0; i < cgraph->n_nodes; i++) { struct lm_ggml_tensor * node = cgraph->nodes[i]; - const int n_tasks = lm_ggml_get_n_tasks(node, n_threads, 1); + const int n_tasks = lm_ggml_get_n_tasks(node, n_threads); max_tasks = MAX(max_tasks, n_tasks); @@ -19527,22 +18667,6 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in { const enum lm_ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; -#if defined(LM_GGML_USE_CLBLAST) - if (lm_ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { - cur = lm_ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); - } else -#endif -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) - if (lm_ggml_compute_forward_mul_mat_use_blas(node)) { - if (node->src[0]->type != LM_GGML_TYPE_F32) { - // here we need memory for fully dequantized matrix from src0 - // take into account that src0 can be broadcasted into src1[2,3] - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) - * node->src[0]->ne[0]*node->src[0]->ne[1] - * node->src[1]->ne[2]*node->src[1]->ne[3]; - } - } else -#endif if (node->src[1]->type != vec_dot_type) { cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(node->src[1])); } @@ -19661,91 +18785,121 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in return cplan; } -enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) { - { - LM_GGML_ASSERT(cplan); - LM_GGML_ASSERT(cplan->n_threads > 0); +static thread_ret_t lm_ggml_graph_compute_thread(void * data) { + struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; + + const struct lm_ggml_cgraph * cgraph = state->shared->cgraph; + const struct lm_ggml_cplan * cplan = state->shared->cplan; + + set_numa_thread_affinity(state->ith); + + struct lm_ggml_compute_params params = { + /*.ith =*/ state->ith, + /*.nth =*/ state->shared->n_threads, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.shared=*/ state->shared, + }; + + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + struct lm_ggml_tensor * node = cgraph->nodes[node_n]; + + lm_ggml_compute_forward(¶ms, node); + + if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + state->shared->ec = LM_GGML_STATUS_ABORTED; + } - if (cplan->work_size > 0) { - LM_GGML_ASSERT(cplan->work_data); + lm_ggml_barrier(state->shared); + + if (state->shared->ec != LM_GGML_STATUS_SUCCESS) { + break; } } - const int n_threads = cplan->n_threads; + return 0; +} + +enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) { + LM_GGML_ASSERT(cplan); + LM_GGML_ASSERT(cplan->n_threads > 0); + LM_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); + + int n_threads = cplan->n_threads; struct lm_ggml_compute_state_shared state_shared = { /*.cgraph =*/ cgraph, /*.cgraph_plan =*/ cplan, - /*.perf_node_start_cycles =*/ 0, - /*.perf_node_start_time_us =*/ 0, /*.n_threads =*/ n_threads, - /*.n_active =*/ n_threads, - /*.node_n =*/ -1, - /*.node_task =*/ LM_GGML_TASK_TYPE_FINALIZE, + /*.n_barrier =*/ 0, + /*.n_barrier_passed =*/ 0, /*.abort_callback =*/ NULL, /*.abort_callback_data =*/ NULL, - /*.current_chunk; =*/ 0, + /*.current_chunk =*/ 0, + /*.ec =*/ LM_GGML_STATUS_SUCCESS, }; - struct lm_ggml_compute_state * workers = alloca(sizeof(struct lm_ggml_compute_state)*n_threads); - // create thread pool +#ifdef LM_GGML_USE_OPENMP if (n_threads > 1) { - for (int j = 1; j < n_threads; ++j) { - workers[j] = (struct lm_ggml_compute_state) { + #pragma omp parallel num_threads(n_threads) + { + #pragma omp single + { + // update the number of threads from the actual number of threads that we got from OpenMP + n_threads = omp_get_num_threads(); + state_shared.n_threads = n_threads; + } + + struct lm_ggml_compute_state worker = { .thrd = 0, - .ith = j, + .ith = omp_get_thread_num(), .shared = &state_shared, - .ec = LM_GGML_STATUS_SUCCESS, }; - - const int rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_thread, &workers[j]); - LM_GGML_ASSERT(rc == 0); - UNUSED(rc); + lm_ggml_graph_compute_thread(&worker); } + } else { + struct lm_ggml_compute_state worker = { + .thrd = 0, + .ith = 0, + .shared = &state_shared, + }; + lm_ggml_graph_compute_thread(&worker); } +#else + struct lm_ggml_compute_state * workers = alloca(sizeof(struct lm_ggml_compute_state)*n_threads); - workers[0].ith = 0; - workers[0].shared = &state_shared; - workers[0].ec = LM_GGML_STATUS_SUCCESS; + for (int j = 0; j < n_threads; ++j) { + workers[j] = (struct lm_ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + }; + } - const int64_t perf_start_cycles = lm_ggml_perf_cycles(); - const int64_t perf_start_time_us = lm_ggml_perf_time_us(); + // create thread pool + for (int j = 1; j < n_threads; ++j) { + const int rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_thread, &workers[j]); + LM_GGML_ASSERT(rc == 0); + UNUSED(rc); + } // this is a work thread too lm_ggml_graph_compute_thread(&workers[0]); - enum lm_ggml_status compute_status = workers[0].ec; - - // don't leave affinity set on the main thread - clear_numa_thread_affinity(); // join or kill thread pool if (n_threads > 1) { for (int j = 1; j < n_threads; j++) { const int rc = lm_ggml_thread_join(workers[j].thrd, NULL); LM_GGML_ASSERT(rc == 0); - if (workers[j].ec != LM_GGML_STATUS_SUCCESS) - compute_status = workers[j].ec; + UNUSED(rc); } } +#endif - // performance stats (graph) - { - int64_t perf_cycles_cur = lm_ggml_perf_cycles() - perf_start_cycles; - int64_t perf_time_us_cur = lm_ggml_perf_time_us() - perf_start_time_us; - - cgraph->perf_runs++; - cgraph->perf_cycles += perf_cycles_cur; - cgraph->perf_time_us += perf_time_us_cur; - - LM_GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", - __func__, cgraph->perf_runs, - (double) perf_cycles_cur / (double) lm_ggml_cycles_per_ms(), - (double) cgraph->perf_cycles / (double) lm_ggml_cycles_per_ms() / (double) cgraph->perf_runs, - (double) perf_time_us_cur / 1000.0, - (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); - } + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); - return compute_status; + return state_shared.ec; } enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads) { @@ -19865,7 +19019,7 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna FILE * fout = lm_ggml_fopen(fname, "wb"); if (!fout) { - fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); return; } @@ -20002,7 +19156,7 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ { FILE * fin = lm_ggml_fopen(fname, "rb"); if (!fin) { - fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); return result; } @@ -20244,24 +19398,16 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ } void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { - int64_t perf_total_per_op_us[LM_GGML_OP_COUNT] = {0}; - LM_GGML_PRINT("=== GRAPH ===\n"); LM_GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); for (int i = 0; i < cgraph->n_nodes; i++) { struct lm_ggml_tensor * node = cgraph->nodes[i]; - perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us); - - LM_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + LM_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], - lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs, - (double) node->perf_cycles / (double) lm_ggml_cycles_per_ms(), - (double) node->perf_cycles / (double) lm_ggml_cycles_per_ms() / (double) node->perf_runs, - (double) node->perf_time_us / 1000.0, - (double) node->perf_time_us / 1000.0 / node->perf_runs); + lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); } LM_GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); @@ -20275,14 +19421,6 @@ void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { lm_ggml_get_name(node)); } - for (int i = 0; i < LM_GGML_OP_COUNT; i++) { - if (perf_total_per_op_us[i] == 0) { - continue; - } - - LM_GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", lm_ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0); - } - LM_GGML_PRINT("========================================\n"); } @@ -20341,7 +19479,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); + fprintf(fp, " rankdir = TB;\n"); for (int i = 0; i < gb->n_nodes; i++) { struct lm_ggml_tensor * node = gb->nodes[i]; @@ -20403,7 +19541,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg } fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); - if (lm_ggml_nelements(node) < 5) { + if (lm_ggml_nelements(node) < 5 && node->data != NULL) { fprintf(fp, " | ("); for (int j = 0; j < lm_ggml_nelements(node); j++) { if (node->type == LM_GGML_TYPE_I8 || node->type == LM_GGML_TYPE_I16 || node->type == LM_GGML_TYPE_I32) { @@ -21459,6 +20597,9 @@ size_t lm_ggml_quantize_chunk( case LM_GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case LM_GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case LM_GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case LM_GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_F16: { size_t elemsize = sizeof(lm_ggml_fp16_t); @@ -21689,6 +20830,7 @@ struct lm_gguf_context * lm_gguf_init_empty(void) { struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gguf_init_params params) { FILE * file = lm_ggml_fopen(fname, "rb"); if (!file) { + fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno)); return NULL; } @@ -21922,8 +21064,8 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg (int64_t) info->ne[3]; if (ne % lm_ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, (int)info->type, lm_ggml_type_name(info->type), ne, lm_ggml_blck_size(info->type)); + fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", + __func__, info->name.data, (int) info->type, lm_ggml_type_name(info->type), ne, lm_ggml_blck_size(info->type)); fclose(file); lm_gguf_free(ctx); return NULL; @@ -22761,8 +21903,6 @@ int lm_ggml_cpu_has_neon(void) { int lm_ggml_cpu_has_sve(void) { #if defined(__ARM_FEATURE_SVE) - // TODO: Currently, SVE 256 bit is only supported. - LM_GGML_ASSERT(svcntb() == QK8_0); return 1; #else return 0; @@ -22810,7 +21950,7 @@ int lm_ggml_cpu_has_wasm_simd(void) { } int lm_ggml_cpu_has_blas(void) { -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) || defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_VULKAN) || defined(LM_GGML_USE_CLBLAST) || defined(LM_GGML_USE_SYCL) +#if defined(LM_GGML_USE_BLAS) || defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_VULKAN) || defined(LM_GGML_USE_SYCL) return 1; #else return 0; @@ -22825,14 +21965,6 @@ int lm_ggml_cpu_has_cuda(void) { #endif } -int lm_ggml_cpu_has_clblast(void) { -#if defined(LM_GGML_USE_CLBLAST) - return 1; -#else - return 0; -#endif -} - int lm_ggml_cpu_has_vulkan(void) { #if defined(LM_GGML_USE_VULKAN) return 1; @@ -22857,9 +21989,24 @@ int lm_ggml_cpu_has_sycl(void) { #endif } +int lm_ggml_cpu_has_rpc(void) { +#if defined(LM_GGML_USE_RPC) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_cann(void) { +#if defined(LM_GGML_USE_CANN) + return 1; +#else + return 0; +#endif +} + int lm_ggml_cpu_has_gpublas(void) { - return lm_ggml_cpu_has_cuda() || lm_ggml_cpu_has_clblast() || lm_ggml_cpu_has_vulkan() || lm_ggml_cpu_has_kompute() || - lm_ggml_cpu_has_sycl(); + return lm_ggml_cpu_has_cuda() || lm_ggml_cpu_has_vulkan() || lm_ggml_cpu_has_kompute() || lm_ggml_cpu_has_sycl(); } int lm_ggml_cpu_has_sse3(void) { diff --git a/cpp/ggml.h b/cpp/ggml.h index 1c5d9f31..53fbf114 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -312,6 +312,12 @@ LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ LM_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) +#define LM_GGML_TENSOR_BINARY_OP_LOCALS01 \ + LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + LM_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + LM_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + #ifdef __cplusplus extern "C" { #endif @@ -377,6 +383,9 @@ extern "C" { LM_GGML_TYPE_F64 = 28, LM_GGML_TYPE_IQ1_M = 29, LM_GGML_TYPE_BF16 = 30, + LM_GGML_TYPE_Q4_0_4_4 = 31, + LM_GGML_TYPE_Q4_0_4_8 = 32, + LM_GGML_TYPE_Q4_0_8_8 = 33, LM_GGML_TYPE_COUNT, }; @@ -418,6 +427,9 @@ extern "C" { LM_GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors LM_GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors LM_GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors + LM_GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors + LM_GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors + LM_GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors }; // available tensor operations: @@ -585,11 +597,7 @@ extern "C" { struct lm_ggml_tensor * grad; struct lm_ggml_tensor * src[LM_GGML_MAX_SRC]; - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; - + // source tensor and offset for views struct lm_ggml_tensor * view_src; size_t view_offs; @@ -599,7 +607,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[8]; + // char padding[4]; }; static const size_t LM_GGML_TENSOR_SIZE = sizeof(struct lm_ggml_tensor); @@ -646,11 +654,6 @@ extern "C" { struct lm_ggml_hash_set visited_hash_table; enum lm_ggml_cgraph_eval_order order; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; }; // scratch buffer @@ -667,28 +670,6 @@ extern "C" { bool no_alloc; // don't allocate memory for the tensor data }; - - // compute types - - // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. - // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. - enum lm_ggml_task_type { - LM_GGML_TASK_TYPE_INIT = 0, - LM_GGML_TASK_TYPE_COMPUTE, - LM_GGML_TASK_TYPE_FINALIZE, - }; - - struct lm_ggml_compute_params { - enum lm_ggml_task_type type; - - // ith = thread index, nth = number of threads - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; - }; - // numa strategies enum lm_ggml_numa_strategy { LM_GGML_NUMA_STRATEGY_DISABLED = 0, @@ -733,9 +714,9 @@ extern "C" { LM_GGML_API LM_GGML_CALL size_t lm_ggml_nbytes (const struct lm_ggml_tensor * tensor); LM_GGML_API size_t lm_ggml_nbytes_pad (const struct lm_ggml_tensor * tensor); // same as lm_ggml_nbytes() but padded to LM_GGML_MEM_ALIGN - LM_GGML_API LM_GGML_CALL int lm_ggml_blck_size(enum lm_ggml_type type); - LM_GGML_API LM_GGML_CALL size_t lm_ggml_type_size(enum lm_ggml_type type); // size in bytes for all elements in a block - LM_GGML_API LM_GGML_CALL size_t lm_ggml_row_size (enum lm_ggml_type type, int64_t ne); // size in bytes for all elements in a row + LM_GGML_API LM_GGML_CALL int64_t lm_ggml_blck_size(enum lm_ggml_type type); + LM_GGML_API LM_GGML_CALL size_t lm_ggml_type_size(enum lm_ggml_type type); // size in bytes for all elements in a block + LM_GGML_API LM_GGML_CALL size_t lm_ggml_row_size (enum lm_ggml_type type, int64_t ne); // size in bytes for all elements in a row LM_GGML_DEPRECATED( LM_GGML_API double lm_ggml_type_sizef(enum lm_ggml_type type), // lm_ggml_type_size()/lm_ggml_blck_size() as float @@ -756,7 +737,6 @@ extern "C" { LM_GGML_API enum lm_ggml_type lm_ggml_ftype_to_lm_ggml_type(enum lm_ggml_ftype ftype); LM_GGML_API LM_GGML_CALL bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor); - LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor); LM_GGML_API LM_GGML_CALL bool lm_ggml_is_permuted (const struct lm_ggml_tensor * tensor); LM_GGML_API LM_GGML_CALL bool lm_ggml_is_empty (const struct lm_ggml_tensor * tensor); LM_GGML_API bool lm_ggml_is_scalar (const struct lm_ggml_tensor * tensor); @@ -765,9 +745,16 @@ extern "C" { LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor); LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars + LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous (const struct lm_ggml_tensor * tensor); + LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor); // same as lm_ggml_is_contiguous() + LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1 + LM_GGML_API LM_GGML_CALL bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2 + LM_GGML_API bool lm_ggml_are_same_shape (const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1); LM_GGML_API bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1); + LM_GGML_API bool lm_ggml_can_repeat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1); + // use this to compute the memory overhead of a tensor LM_GGML_API size_t lm_ggml_tensor_overhead(void); @@ -1461,7 +1448,6 @@ extern "C" { // rotary position embedding // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED) // if mode & 2 == 1, GPT-NeoX style - // if mode & 4 == 1, ChatGLM style // // b is an int32 vector with size a->ne[2], it contains the positions // c is freq factors (e.g. phi3-128k), (optional) @@ -1470,8 +1456,7 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int n_dims, - int mode, - int n_ctx); + int mode); // in-place, returns view(a) LM_GGML_API struct lm_ggml_tensor * lm_ggml_rope_inplace( @@ -1479,8 +1464,7 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int n_dims, - int mode, - int n_ctx); + int mode); // custom RoPE LM_GGML_API struct lm_ggml_tensor * lm_ggml_rope_ext( @@ -1490,8 +1474,7 @@ extern "C" { struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1507,8 +1490,7 @@ extern "C" { struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1522,8 +1504,7 @@ extern "C" { struct lm_ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1538,8 +1519,7 @@ extern "C" { struct lm_ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1550,7 +1530,7 @@ extern "C" { // compute correction dims for YaRN RoPE scaling LM_GGML_CALL void lm_ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1561,16 +1541,13 @@ extern "C" { struct lm_ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down); + float beta_slow); // clamp // in-place, returns view(a) @@ -2413,15 +2390,16 @@ extern "C" { LM_GGML_API int lm_ggml_cpu_has_wasm_simd (void); LM_GGML_API int lm_ggml_cpu_has_blas (void); LM_GGML_API int lm_ggml_cpu_has_cuda (void); - LM_GGML_API int lm_ggml_cpu_has_clblast (void); LM_GGML_API int lm_ggml_cpu_has_vulkan (void); LM_GGML_API int lm_ggml_cpu_has_kompute (void); LM_GGML_API int lm_ggml_cpu_has_gpublas (void); LM_GGML_API int lm_ggml_cpu_has_sse3 (void); LM_GGML_API int lm_ggml_cpu_has_ssse3 (void); LM_GGML_API int lm_ggml_cpu_has_sycl (void); + LM_GGML_API int lm_ggml_cpu_has_rpc (void); LM_GGML_API int lm_ggml_cpu_has_vsx (void); LM_GGML_API int lm_ggml_cpu_has_matmul_int8(void); + LM_GGML_API int lm_ggml_cpu_has_cann (void); // // Internal types and functions exposed for tests and benchmarks @@ -2435,20 +2413,31 @@ extern "C" { #endif typedef void (*lm_ggml_to_float_t) (const void * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); typedef void (*lm_ggml_from_float_t)(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); - typedef void (*lm_ggml_vec_dot_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, size_t bx, - const void * LM_GGML_RESTRICT y, size_t by, int nrc); + typedef void (*lm_ggml_from_float_to_mat_t) + (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); + typedef void (*lm_ggml_vec_dot_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, size_t bx, + const void * LM_GGML_RESTRICT y, size_t by, int nrc); + typedef void (*lm_ggml_gemv_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, + const void * LM_GGML_RESTRICT y, int nr, int nc); + typedef void (*lm_ggml_gemm_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, + const void * LM_GGML_RESTRICT y, int nr, int nc); typedef struct { - const char * type_name; - int blck_size; - size_t type_size; - bool is_quantized; - lm_ggml_to_float_t to_float; - lm_ggml_from_float_t from_float; - lm_ggml_from_float_t from_float_reference; - lm_ggml_vec_dot_t vec_dot; - enum lm_ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously; + const char * type_name; + int64_t blck_size; + int64_t blck_size_interleave; // interleave elements in blocks + size_t type_size; + bool is_quantized; + lm_ggml_to_float_t to_float; + lm_ggml_from_float_t from_float; + lm_ggml_from_float_t from_float_ref; + lm_ggml_from_float_to_mat_t from_float_to_mat; + lm_ggml_vec_dot_t vec_dot; + enum lm_ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + int64_t ncols; // number of columns to process simultaneously + lm_ggml_gemv_t gemv; + lm_ggml_gemm_t gemm; } lm_ggml_type_traits_t; LM_GGML_API lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type); diff --git a/cpp/grammar-parser.cpp b/cpp/grammar-parser.cpp index b5bc7d49..a518b766 100644 --- a/cpp/grammar-parser.cpp +++ b/cpp/grammar-parser.cpp @@ -46,8 +46,12 @@ namespace grammar_parser { state.rules[rule_id] = rule; } + static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; + } + static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); } static std::pair parse_hex(const char * src, int size) { @@ -99,6 +103,17 @@ namespace grammar_parser { return pos; } + static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; + } + static std::pair parse_char(const char * src) { if (*src == '\\') { switch (src[1]) { @@ -137,6 +152,60 @@ namespace grammar_parser { bool is_nested) { size_t last_sym_start = out_elements.size(); const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); + if (min_times == 0) { + out_elements.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + std::vector rec_rule(previous_elements); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(previous_elements.size()); + uint32_t rec_rule_id = generate_symbol_id(state, rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + while (*pos) { if (*pos == '"') { // literal string pos++; @@ -197,40 +266,51 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); - } + } else if (*pos == '.') { // any char + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); - // apply transformation to previous symbol (last_sym_start to end) according to - // rewrite rules: - // S* --> S' ::= S S' | - // S+ --> S' ::= S S' | S - // S? --> S' ::= S | - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - std::vector sub_rule; - // add preceding symbol to generated rule - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - if (*pos == '*' || *pos == '+') { - // cause generated rule to recurse - sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); } - // mark start of alternate def - sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if (*pos == '+') { - // add preceding symbol as alternate only for '+' (otherwise empty) - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - } - sub_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, sub_rule_id, sub_rule); + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); - // in original rule, replace previous symbol with reference to generated rule - out_elements.resize(last_sym_start); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + int max_times = -1; - pos = parse_space(pos + 1, is_nested); + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); } else { break; } @@ -325,6 +405,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; default: return false; } } @@ -339,6 +420,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -350,6 +432,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "(\""); print_grammar_char(file, elem.value); fprintf(file, "\") "); @@ -407,11 +490,15 @@ namespace grammar_parser { } print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: break; default: fprintf(file, "] "); diff --git a/cpp/json-schema-to-grammar.cpp b/cpp/json-schema-to-grammar.cpp index 9a71f5d8..881eb49e 100644 --- a/cpp/json-schema-to-grammar.cpp +++ b/cpp/json-schema-to-grammar.cpp @@ -16,92 +16,282 @@ static std::string join(Iterator begin, Iterator end, const std::string & separa static std::string repeat(const std::string & str, size_t n); -static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) { +static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { + auto has_max = max_items != std::numeric_limits::max(); + + if (min_items == 0 && max_items == 1) { + return item_rule + "?"; + } + if (separator_rule.empty()) { - if (min_items == 0 && max_items == 1) { - return item_rule + "?"; - } else if (min_items == 1 && max_items == std::numeric_limits::max()) { + if (min_items == 1 && !has_max) { return item_rule + "+"; + } else if (min_items == 0 && !has_max) { + return item_rule + "*"; + } else { + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } } - std::string result; - if (min_items > 0) { - if (item_rule_is_literal && separator_rule.empty()) { - result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\""; - } else { - std::vector items(min_items, item_rule); - result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " "); + auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); + if (min_items == 0) { + result = "(" + result + ")?"; + } + return result; +} + +/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ +class string_view { + const std::string & _str; + const size_t _start; + const size_t _end; +public: + string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} + + size_t size() const { + return _end - _start; + } + + size_t length() const { + return size(); + } + + operator std::string() const { + return str(); + } + + std::string str() const { + return _str.substr(_start, _end - _start); + } + + string_view substr(size_t pos, size_t len = std::string::npos) const { + return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); + } + + char operator[](size_t pos) const { + auto index = _start + pos; + if (index >= _end) { + throw std::out_of_range("string_view index out of range"); } + return _str[_start + pos]; } - std::function opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string { - auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule; + bool operator==(const string_view & other) const { + std::string this_str = *this; + std::string other_str = other; + return this_str == other_str; + } +}; - if (up_to_n == 0) { - return ""; - } else if (up_to_n == 1) { - return "(" + content + ")?"; - } else if (!separator_rule.empty() && !prefix_with_sep) { - return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?"; +static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { + auto has_min = min_value != std::numeric_limits::min(); + auto has_max = max_value != std::numeric_limits::max(); + + auto digit_range = [&](char from, char to) { + out << "["; + if (from == to) { + out << from; } else { - std::string res = repeat("(" + content + " ", up_to_n); - // strip trailing space - res = res.substr(0, res.length() - 1); - res += repeat(")?", up_to_n); - return res; + out << from << "-" << to; } + out << "]"; }; + auto more_digits = [&](int min_digits, int max_digits) { + out << "[0-9]"; + if (min_digits == max_digits && min_digits == 1) { + return; + } + out << "{"; + out << min_digits; + if (max_digits != min_digits) { + out << ","; + if (max_digits != std::numeric_limits::max()) { + out << max_digits; + } + } + out << "}"; + }; + std::function uniform_range = + [&](const string_view & from, const string_view & to) { + size_t i = 0; + while (i < from.length() && i < to.length() && from[i] == to[i]) { + i++; + } + if (i > 0) { + out << "\"" << from.substr(0, i).str() << "\""; + } + if (i < from.length() && i < to.length()) { + if (i > 0) { + out << " "; + } + auto sub_len = from.length() - i - 1; + if (sub_len > 0) { + auto from_sub = from.substr(i + 1); + auto to_sub = to.substr(i + 1); + auto sub_zeros = repeat("0", sub_len); + auto sub_nines = repeat("9", sub_len); + + auto to_reached = false; + out << "("; + if (from_sub == sub_zeros) { + digit_range(from[i], to[i] - 1); + out << " "; + more_digits(sub_len, sub_len); + } else { + out << "[" << from[i] << "] "; + out << "("; + uniform_range(from_sub, sub_nines); + out << ")"; + if (from[i] < to[i] - 1) { + out << " | "; + if (to_sub == sub_nines) { + digit_range(from[i] + 1, to[i]); + to_reached = true; + } else { + digit_range(from[i] + 1, to[i] - 1); + } + out << " "; + more_digits(sub_len, sub_len); + } + } + if (!to_reached) { + out << " | "; + digit_range(to[i], to[i]); + out << " "; + uniform_range(sub_zeros, to_sub); + } + out << ")"; + } else { + out << "[" << from[i] << "-" << to[i] << "]"; + } + } + }; + + if (has_min && has_max) { + if (min_value < 0 && max_value < 0) { + out << "\"-\" ("; + _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); + out << ")"; + return; + } + + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); + out << ") | "; + min_value = 0; + } + + auto min_s = std::to_string(min_value); + auto max_s = std::to_string(max_value); + auto min_digits = min_s.length(); + auto max_digits = max_s.length(); - if (min_items > 0 && max_items != min_items) { - result += " "; + for (auto digits = min_digits; digits < max_digits; digits++) { + uniform_range(min_s, repeat("9", digits)); + min_s = "1" + repeat("0", digits); + out << " | "; + } + uniform_range(min_s, max_s); + return; + } + + auto less_decimals = std::max(decimals_left - 1, 1); + + if (has_min) { + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + out << ") | [0] | [1-9] "; + more_digits(0, decimals_left - 1); + } else if (min_value == 0) { + if (top_level) { + out << "[0] | [1-9] "; + more_digits(0, less_decimals); + } else { + more_digits(1, decimals_left); + } + } else if (min_value <= 9) { + char c = '0' + min_value; + auto range_start = top_level ? '1' : '0'; + if (c > range_start) { + digit_range(range_start, c - 1); + out << " "; + more_digits(1, less_decimals); + out << " | "; + } + digit_range(c, '9'); + out << " "; + more_digits(0, less_decimals); + } else { + auto min_s = std::to_string(min_value); + auto len = min_s.length(); + auto c = min_s[0]; + + if (c > '1') { + digit_range(top_level ? '1' : '0', c - 1); + out << " "; + more_digits(len, less_decimals); + out << " | "; + } + digit_range(c, c); + out << " ("; + _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + out << ")"; + if (c < '9') { + out << " | "; + digit_range(c + 1, '9'); + out << " "; + more_digits(len - 1, less_decimals); + } + } + return; } - if (max_items != std::numeric_limits::max()) { - result += opt_repetitions(max_items - min_items, min_items > 0); - } else { - std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")"; - if (min_items == 0 && !separator_rule.empty()) { - result = "(" + item_rule + " " + item_operator + "*)?"; + if (has_max) { + if (max_value >= 0) { + if (top_level) { + out << "\"-\" [1-9] "; + more_digits(0, less_decimals); + out << " | "; + } + _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); } else { - result += item_operator + "*"; + out << "\"-\" ("; + _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + out << ")"; } + return; } - return result; + throw std::runtime_error("At least one of min_value or max_value must be set"); } -const std::string SPACE_RULE = "\" \"?"; +const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}"; struct BuiltinRule { std::string content; std::vector deps; }; -const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15); - std::unordered_map PRIMITIVE_RULES = { {"boolean", {"(\"true\" | \"false\") space", {}}}, - {"decimal-part", {"[0-9] " + _up_to_15_digits, {}}}, - {"integral-part", {"[0-9] | [1-9] " + _up_to_15_digits, {}}}, + {"decimal-part", {"[0-9]{1,16}", {}}}, + {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, - {"uuid", {"\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space", {}}}, - {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])", {}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, {"null", {"\"null\" space", {}}}, }; std::unordered_map STRING_FORMAT_RULES = { - {"date", {"[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, - {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, + {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, + {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, @@ -126,7 +316,7 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { }; std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; -std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; +std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; template std::string join(Iterator begin, Iterator end, const std::string & separator) { @@ -197,7 +387,6 @@ static std::string format_literal(const std::string & literal) { return "\"" + escaped + "\""; } - class SchemaConverter { private: std::function _fetch_json; @@ -385,8 +574,7 @@ class SchemaConverter { sub_is_literal ? "\"" + sub + "\"" : sub, min_times, max_times, - "", - sub_is_literal + "" ); seq.back().second = false; } else { @@ -426,6 +614,75 @@ class SchemaConverter { return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); } + /* + Returns a rule that matches a JSON string that is none of the provided strings + + not_strings({"a"}) + -> ["] ( [a] char+ | [^"a] char* )? ["] space + not_strings({"and", "also"}) + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + */ + std::string _not_strings(const std::vector & strings) { + + struct TrieNode { + std::map children; + bool is_end_of_string; + + TrieNode() : is_end_of_string(false) {} + + void insert(const std::string & string) { + auto node = this; + for (char c : string) { + node = &node->children[c]; + } + node->is_end_of_string = true; + } + }; + + TrieNode trie; + for (const auto & s : strings) { + trie.insert(s); + } + + std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); + std::ostringstream out; + out << "[\"] ( "; + std::function visit = [&](const TrieNode & node) { + std::ostringstream rejects; + auto first = true; + for (const auto & kv : node.children) { + rejects << kv.first; + if (first) { + first = false; + } else { + out << " | "; + } + out << "[" << kv.first << "]"; + if (!kv.second.children.empty()) { + out << " ("; + visit(kv.second); + out << ")"; + } else if (kv.second.is_end_of_string) { + out << " " << char_rule << "+"; + } + } + if (!node.children.empty()) { + if (!first) { + out << " | "; + } + out << "[^\"" << rejects.str() << "] " << char_rule << "*"; + } + }; + visit(trie); + + out << " )"; + if (!trie.is_end_of_string) { + out << "?"; + } + out << " [\"] space"; + return out.str(); + } + std::string _resolve_ref(const std::string & ref) { std::string ref_name = ref.substr(ref.find_last_of('/') + 1); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { @@ -446,6 +703,7 @@ class SchemaConverter { std::vector required_props; std::vector optional_props; std::unordered_map prop_kv_rule_names; + std::vector prop_names; for (const auto & kv : properties) { const auto &prop_name = kv.first; const auto &prop_schema = kv.second; @@ -460,11 +718,18 @@ class SchemaConverter { } else { optional_props.push_back(prop_name); } + prop_names.push_back(prop_name); } - if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get())) { + if ((additional_properties.is_boolean() && additional_properties.get()) || additional_properties.is_object()) { std::string sub_name = name + (name.empty() ? "" : "-") + "additional"; - std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value"); - std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule); + std::string value_rule = + additional_properties.is_object() ? visit(additional_properties, sub_name + "-value") + : _add_primitive("value", PRIMITIVE_RULES.at("value")); + + auto key_rule = + prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string")) + : _add_rule(sub_name + "-k", _not_strings(prop_names)); + std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule); prop_kv_rule_names["*"] = kv_rule; optional_props.push_back("*"); } @@ -490,15 +755,11 @@ class SchemaConverter { } std::string k = ks[0]; std::string kv_rule_name = prop_kv_rule_names[k]; - if (k == "*") { - res = _add_rule( - name + (name.empty() ? "" : "-") + "additional-kvs", - kv_rule_name + " ( \",\" space " + kv_rule_name + " )*" - ); - } else if (first_is_optional) { - res = "( \",\" space " + kv_rule_name + " )?"; + std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; + if (first_is_optional) { + res = comma_ref + (k == "*" ? "*" : "?"); } else { - res = kv_rule_name; + res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : ""); } if (ks.size() > 1) { res += " " + _add_rule( @@ -632,17 +893,19 @@ class SchemaConverter { } else if (schema_type.is_array()) { std::vector schema_types; for (const auto & t : schema_type) { - schema_types.push_back({{"type", t}}); + json schema_copy(schema); + schema_copy["type"] = t; + schema_types.push_back(schema_copy); } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } else if (schema.contains("const")) { - return _add_rule(rule_name, _generate_constant_rule(schema["const"])); + return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); } else if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | ")); + return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -724,6 +987,24 @@ class SchemaConverter { int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); + } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { + int min_value = std::numeric_limits::min(); + int max_value = std::numeric_limits::max(); + if (schema.contains("minimum")) { + min_value = schema["minimum"].get(); + } else if (schema.contains("exclusiveMinimum")) { + min_value = schema["exclusiveMinimum"].get() + 1; + } + if (schema.contains("maximum")) { + max_value = schema["maximum"].get(); + } else if (schema.contains("exclusiveMaximum")) { + max_value = schema["exclusiveMaximum"].get() - 1; + } + std::stringstream out; + out << "("; + _build_min_max_int(min_value, max_value, out); + out << ") space"; + return _add_rule(rule_name, out.str()); } else if (schema.empty() || schema_type == "object") { return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); } else { diff --git a/cpp/llama.cpp b/cpp/llama.cpp index 5084bee1..5883887a 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -13,14 +13,18 @@ #ifdef LM_GGML_USE_CUDA # include "ggml-cuda.h" -#elif defined(LM_GGML_USE_CLBLAST) -# include "ggml-opencl.h" #elif defined(LM_GGML_USE_VULKAN) # include "ggml-vulkan.h" #elif defined(LM_GGML_USE_SYCL) # include "ggml-sycl.h" #elif defined(LM_GGML_USE_KOMPUTE) # include "ggml-kompute.h" +#elif defined(LM_GGML_USE_CANN) +# include "ggml-cann.h" +#endif + +#ifdef LM_GGML_USE_BLAS +# include "ggml-blas.h" #endif #ifdef LM_GGML_USE_METAL @@ -55,6 +59,12 @@ #include #endif +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + #include #include #include @@ -102,15 +112,17 @@ #define LLAMA_ATTRIBUTE_FORMAT(...) #endif +// bump if necessary #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 160 +#define LLAMA_MAX_LAYERS 512 +#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 // // logging // LLAMA_ATTRIBUTE_FORMAT(2, 3) -static void llama_log_internal (lm_ggml_log_level level, const char* format, ...); +static void llama_log_internal (lm_ggml_log_level level, const char * format, ...); static void llama_log_callback_default(lm_ggml_log_level level, const char * text, void * user_data); #define LLAMA_LOG_INFO(...) llama_log_internal(LM_GGML_LOG_LEVEL_INFO , __VA_ARGS__) @@ -226,14 +238,20 @@ enum llm_arch { LLM_ARCH_INTERNLM2, LLM_ARCH_MINICPM, LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_OPENELM, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_BITNET, + LLM_ARCH_T5, + LLM_ARCH_JAIS, LLM_ARCH_UNKNOWN, }; @@ -264,18 +282,25 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_INTERNLM2, "internlm2" }, { LLM_ARCH_MINICPM, "minicpm" }, { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OPENELM, "openelm" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; enum llm_kv { + LLM_KV_GENERAL_TYPE, LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, @@ -295,6 +320,7 @@ enum llm_kv { LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -303,6 +329,9 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -315,6 +344,8 @@ enum llm_kv { LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -352,15 +383,21 @@ enum llm_kv { LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, LLM_KV_TOKENIZER_EOT_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, }; static const std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_TYPE, "general.type" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, @@ -373,33 +410,39 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, - { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, - - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, - { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, - { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, - { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -420,29 +463,34 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, - { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, - { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, }; struct LLM_KV { @@ -473,10 +521,12 @@ enum llm_tensor { LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_NORM_2, LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_ATTN_ROT_EMBD, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, @@ -507,6 +557,36 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -717,6 +797,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -968,6 +1049,24 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_GEMMA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1068,6 +1167,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OPENELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_ARCTIC, { @@ -1119,6 +1234,89 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { + LLM_ARCH_T5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JAIS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1286,6 +1484,126 @@ struct no_init { }; struct llama_file { + +#if defined(_WIN32) + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + HANDLE fp_win32; + size_t size; + +private: + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %s", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + +public: + + llama_file(const char * fname, const char * mode) { + fp = lm_ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + // SetFilePointerEx returns the current position when seeking relative 0 bytes + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void seek(size_t offset, int whence) const { + // no need to convert SEEK_* to FILE_*. The enums are the same. + // Still, keep static asserts to avoid failures in the future. + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus + // use the Win32 API to do file io instead of the C/C++ library functions. + + // There are conditions under which ReadFile cannot read chunks >64MB. + // Thus split the operation into smaller chunks if len exceeds this limit. + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } ; + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + // There are conditions under which WriteFile cannot write chunks >64MB. + // Thus split the operation into smaller chunks if len exceeds this limit. + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +#else // use FILE * so we don't have to re-open the file to mmap FILE * fp; size_t size; @@ -1306,7 +1624,10 @@ struct llama_file { #else long ret = std::ftell(fp); #endif - LM_GGML_ASSERT(ret != -1); // this really shouldn't fail + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + return (size_t) ret; } @@ -1316,7 +1637,9 @@ struct llama_file { #else int ret = std::fseek(fp, (long) offset, whence); #endif - LM_GGML_ASSERT(ret == 0); // same + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } } void read_raw(void * ptr, size_t len) const { @@ -1359,6 +1682,7 @@ struct llama_file { std::fclose(fp); } } +#endif }; using llama_files = std::vector>; @@ -1713,19 +2037,21 @@ struct llama_mlock { }; using llama_mlocks = std::vector>; -static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); - LM_GGML_ASSERT(check == -n_tokens); +// NOTE: avoid ever using this except for building the token_to_piece caches +static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache + const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + LM_GGML_ASSERT(check == -n_chars); } else { - result.resize(n_tokens); + piece.resize(n_chars); } - return std::string(result.data(), result.size()); + return piece; } static lm_ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) { @@ -1766,6 +2092,8 @@ struct llama_state { lm_ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data); #elif defined(LM_GGML_USE_CUDA) lm_ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data); +#elif defined(LM_GGML_USE_CANN) + lm_ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data); #endif } @@ -1783,22 +2111,34 @@ enum e_model { MODEL_17M, MODEL_22M, MODEL_33M, + MODEL_60M, MODEL_70M, + MODEL_80M, MODEL_109M, MODEL_137M, MODEL_160M, + MODEL_220M, + MODEL_250M, + MODEL_270M, MODEL_335M, MODEL_410M, + MODEL_450M, + MODEL_770M, + MODEL_780M, MODEL_0_5B, MODEL_1B, + MODEL_1_3B, MODEL_1_4B, MODEL_2B, MODEL_2_8B, MODEL_3B, MODEL_4B, + MODEL_6B, MODEL_6_9B, MODEL_7B, MODEL_8B, + MODEL_9B, + MODEL_11B, MODEL_12B, MODEL_13B, MODEL_14B, @@ -1822,6 +2162,8 @@ enum e_model { MODEL_8x22B, MODEL_16x12B, MODEL_10B_128x3_66B, + MODEL_57B_A14B, + MODEL_27B, }; static const size_t kiB = 1024; @@ -1836,31 +2178,38 @@ struct llama_hparams { uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_head; - uint32_t n_head_kv; uint32_t n_layer; uint32_t n_rot; + uint32_t n_swa = 0; // sliding window attention (SWA) uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head - uint32_t n_ff; uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_rel_attn_bkts = 0; + + std::array n_head_arr; + std::array n_head_kv_arr; + std::array n_ff_arr; uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; uint32_t n_expert_shared = 0; float expert_weights_scale = 0.0; float f_norm_eps; float f_norm_rms_eps; + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; - uint32_t n_yarn_orig_ctx; + uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul; // for State Space Models @@ -1873,8 +2222,13 @@ struct llama_hparams { float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; - bool causal_attn = true; - bool use_alibi = false; + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; + + // needed by encoder-decoder models (e.g. T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = -1; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -1885,30 +2239,36 @@ struct llama_hparams { if (this->n_vocab != other.n_vocab) return true; if (this->n_ctx_train != other.n_ctx_train) return true; if (this->n_embd != other.n_embd) return true; - if (this->n_head != other.n_head) return true; - if (this->n_head_kv != other.n_head_kv) return true; if (this->n_layer != other.n_layer) return true; if (this->n_rot != other.n_rot) return true; + if (this->n_swa != other.n_swa) return true; if (this->n_embd_head_k != other.n_embd_head_k) return true; if (this->n_embd_head_v != other.n_embd_head_v) return true; - if (this->n_ff != other.n_ff) return true; if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_head_arr != other.n_head_arr) return true; + if (this->n_head_kv_arr != other.n_head_kv_arr) return true; + if (this->n_ff_arr != other.n_ff_arr) return true; + + if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; if (this->n_lora_q != other.n_lora_q) return true; if (this->n_lora_kv != other.n_lora_kv) return true; if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_ff_shexp != other.n_ff_shexp) return true; if (this->n_expert_shared != other.n_expert_shared) return true; if (this->rope_finetuned != other.rope_finetuned) return true; - if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; + if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true; if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->dec_start_token_id != other.dec_start_token_id) return true; + const float EPSILON = 1e-9f; if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; @@ -1922,18 +2282,53 @@ struct llama_hparams { return false; } - uint32_t n_gqa() const { + uint32_t n_head(uint32_t il = 0) const { + if (il < n_layer) { + return n_head_arr[il]; + } + + LM_GGML_ASSERT(false); + return 0; + } + + uint32_t n_head_kv(uint32_t il = 0) const { + if (il < n_layer) { + return n_head_kv_arr[il]; + } + + LM_GGML_ASSERT(false); + return 0; + } + + uint32_t n_ff(uint32_t il = 0) const { + if (il < n_layer) { + return n_ff_arr[il]; + } + + LM_GGML_ASSERT(false); + return 0; + } + + uint32_t n_gqa(uint32_t il = 0) const { + const uint32_t n_head = this->n_head(il); + const uint32_t n_head_kv = this->n_head_kv(il); + if (n_head_kv == 0) { return 0; } + return n_head/n_head_kv; } - uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t il = 0) const { // dimension of key embeddings across all k-v heads + const uint32_t n_head_kv = this->n_head_kv(il); + return n_embd_head_k * n_head_kv; } - uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t il = 0) const { // dimension of value embeddings across all k-v heads + const uint32_t n_head_kv = this->n_head_kv(il); + return n_embd_head_v * n_head_kv; } @@ -1950,6 +2345,8 @@ struct llama_hparams { } }; +static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); + struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; @@ -1961,7 +2358,7 @@ struct llama_cparams { float rope_freq_base; float rope_freq_scale; - uint32_t n_yarn_orig_ctx; + uint32_t n_ctx_orig_yarn; // These hyperparameters are not exposed in GGUF, because all // existing YaRN models use the same values for them. float yarn_ext_factor; @@ -1981,6 +2378,7 @@ struct llama_cparams { void * cb_eval_user_data; }; +// TODO: separate into "llama_layer_enc" and "llama_layer_dec" struct llama_layer { // normalization struct lm_ggml_tensor * attn_norm; @@ -1995,6 +2393,11 @@ struct llama_layer { struct lm_ggml_tensor * attn_out_norm_b; struct lm_ggml_tensor * attn_q_a_norm; struct lm_ggml_tensor * attn_kv_a_norm; + struct lm_ggml_tensor * attn_sub_norm; + struct lm_ggml_tensor * attn_post_norm; + struct lm_ggml_tensor * ffn_sub_norm; + struct lm_ggml_tensor * attn_norm_cross; + struct lm_ggml_tensor * attn_norm_enc; // attention struct lm_ggml_tensor * wq; @@ -2006,6 +2409,14 @@ struct llama_layer { struct lm_ggml_tensor * wq_b; struct lm_ggml_tensor * wkv_a_mqa; struct lm_ggml_tensor * wkv_b; + struct lm_ggml_tensor * wq_cross; + struct lm_ggml_tensor * wk_cross; + struct lm_ggml_tensor * wv_cross; + struct lm_ggml_tensor * wo_cross; + struct lm_ggml_tensor * wq_enc; + struct lm_ggml_tensor * wk_enc; + struct lm_ggml_tensor * wv_enc; + struct lm_ggml_tensor * wo_enc; // attention bias struct lm_ggml_tensor * bq; @@ -2014,17 +2425,27 @@ struct llama_layer { struct lm_ggml_tensor * bo; struct lm_ggml_tensor * bqkv; + // relative position bias + struct lm_ggml_tensor * attn_rel_b; + struct lm_ggml_tensor * attn_rel_b_enc; + struct lm_ggml_tensor * attn_rel_b_cross; + // normalization struct lm_ggml_tensor * ffn_norm; struct lm_ggml_tensor * ffn_norm_b; + struct lm_ggml_tensor * ffn_post_norm; struct lm_ggml_tensor * layer_out_norm; struct lm_ggml_tensor * layer_out_norm_b; struct lm_ggml_tensor * ffn_norm_exps; + struct lm_ggml_tensor * ffn_norm_enc; // ff struct lm_ggml_tensor * ffn_gate; // w1 struct lm_ggml_tensor * ffn_down; // w2 struct lm_ggml_tensor * ffn_up; // w3 + struct lm_ggml_tensor * ffn_gate_enc; + struct lm_ggml_tensor * ffn_down_enc; + struct lm_ggml_tensor * ffn_up_enc; // ff MoE struct lm_ggml_tensor * ffn_gate_inp; @@ -2062,6 +2483,15 @@ struct llama_layer { // long rope factors struct lm_ggml_tensor * rope_long = nullptr; struct lm_ggml_tensor * rope_short = nullptr; + + // bitnet scale + struct lm_ggml_tensor * wq_scale; + struct lm_ggml_tensor * wk_scale; + struct lm_ggml_tensor * wv_scale; + struct lm_ggml_tensor * wo_scale; + struct lm_ggml_tensor * ffn_gate_scale; + struct lm_ggml_tensor * ffn_up_scale; + struct lm_ggml_tensor * ffn_down_scale; }; struct llama_kv_cell { @@ -2139,13 +2569,21 @@ struct llama_control_vector { int32_t layer_start = -1; int32_t layer_end = -1; - lm_ggml_tensor * tensor_for(int il) const { + struct lm_ggml_tensor * tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } return tensors[il]; } + struct lm_ggml_tensor * apply_to(struct lm_ggml_context * ctx, struct lm_ggml_tensor * cur, int il) const { + lm_ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = lm_ggml_add(ctx, cur, layer_dir); + } + return cur; + } + ~llama_control_vector() { for (struct lm_ggml_context * ctx : ctxs) { lm_ggml_free(ctx); @@ -2159,21 +2597,24 @@ struct llama_control_vector { struct llama_vocab { using id = int32_t; using token = std::string; - using ttype = llama_token_type; + using tattr = llama_token_attr; struct token_data { token text; float score; - ttype type; + tattr attr; }; enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + int max_token_len = 0; // used for optimizing longest token search + std::unordered_map token_to_id; std::vector id_to_token; - std::vector special_tokens_cache; + std::vector cache_special_tokens; + std::vector cache_token_to_piece; // llama_token_to_piece(special = true); std::map, int> bpe_ranks; @@ -2186,16 +2627,23 @@ struct llama_vocab { id special_cls_id = -1; id special_mask_id = -1; - int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add. - int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. - id linefeed_id = 13; id special_prefix_id = -1; id special_suffix_id = -1; id special_middle_id = -1; id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token - bool add_space_prefix = true; + // tokenizer flags + bool tokenizer_add_space_prefix = false; + bool tokenizer_add_bos = false; + bool tokenizer_add_eos = false; + bool tokenizer_ignore_merges = false; + bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces + bool tokenizer_remove_extra_whitespaces = false; + bool tokenizer_escape_whitespaces = true; + bool tokenizer_treat_whitespace_as_suffix = false; + + std::vector precompiled_charsmap; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const { LM_GGML_ASSERT(token_left.find(' ') == std::string::npos); @@ -2232,6 +2680,7 @@ struct llama_model { struct lm_ggml_tensor * output_norm_b; struct lm_ggml_tensor * output; struct lm_ggml_tensor * output_b; + struct lm_ggml_tensor * output_norm_enc; std::vector layers; @@ -2277,6 +2726,9 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; + // keep track of loaded lora adapters + std::set lora_adapters; + ~llama_model() { for (struct lm_ggml_context * ctx : ctxs) { lm_ggml_free(ctx); @@ -2289,6 +2741,9 @@ struct llama_model { #endif lm_ggml_backend_buffer_free(buf); } + while (!lora_adapters.empty()) { + llama_lora_adapter_free(*lora_adapters.begin()); + } } }; @@ -2309,9 +2764,13 @@ struct llama_context { std::vector backends; #ifdef LM_GGML_USE_METAL lm_ggml_backend_t backend_metal = nullptr; +#endif +#ifdef LM_GGML_USE_BLAS + lm_ggml_backend_t backend_blas = nullptr; #endif lm_ggml_backend_t backend_cpu = nullptr; + const llama_model & model; // key + value cache for the self attention @@ -2356,6 +2815,13 @@ struct llama_context { // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; + // whether we are computing encoder output or decoder output + bool is_encoding = false; + + // output of the encoder part of the encoder-decoder models + std::vector embd_enc; + std::vector> seq_ids_enc; + // memory buffers used to evaluate the model std::vector buf_compute_meta; lm_ggml_backend_sched_t sched = nullptr; @@ -2364,29 +2830,102 @@ struct llama_context { void * abort_callback_data = nullptr; // input tensors - struct lm_ggml_tensor * inp_tokens; // I32 [n_batch] - struct lm_ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - struct lm_ggml_tensor * inp_pos; // I32 [n_batch] - struct lm_ggml_tensor * inp_out_ids; // I32 [n_outputs] - struct lm_ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct lm_ggml_tensor * inp_K_shift; // I32 [kv_size] - struct lm_ggml_tensor * inp_mean; // F32 [n_batch, n_batch] - struct lm_ggml_tensor * inp_cls; // I32 [n_batch] - struct lm_ggml_tensor * inp_s_copy; // I32 [kv_size] - struct lm_ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct lm_ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct lm_ggml_tensor * inp_tokens; // I32 [n_batch] + struct lm_ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + struct lm_ggml_tensor * inp_pos; // I32 [n_batch] + struct lm_ggml_tensor * inp_out_ids; // I32 [n_outputs] + struct lm_ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct lm_ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] + struct lm_ggml_tensor * inp_K_shift; // I32 [kv_size] + struct lm_ggml_tensor * inp_mean; // F32 [n_batch, n_batch] + struct lm_ggml_tensor * inp_cls; // I32 [n_batch] + struct lm_ggml_tensor * inp_s_copy; // I32 [kv_size] + struct lm_ggml_tensor * inp_s_mask; // F32 [1, n_kv] + struct lm_ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct lm_ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct lm_ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] + struct lm_ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] // control vectors struct llama_control_vector cvec; + + // lora adapters and scales + std::unordered_map lora_adapters; +}; + +struct llama_lora_weight { + struct lm_ggml_tensor * a = nullptr; + struct lm_ggml_tensor * b = nullptr; + llama_lora_weight() = default; + llama_lora_weight(struct lm_ggml_tensor * a, struct lm_ggml_tensor * b): a(a), b(b) {} +}; + +struct llama_lora_adapter { + struct llama_model * base_model; + // map tensor name to lora_a_b + std::unordered_map ab_map; + std::vector ctxs; + std::vector bufs; + + float alpha; + + llama_lora_adapter(struct llama_model * base_model): base_model(base_model) { + base_model->lora_adapters.insert(this); + } + + llama_lora_weight * get_weight(struct lm_ggml_tensor * w) { + std::string name(w->name); + auto pos = ab_map.find(name); + if (ab_map.find(name) != ab_map.end()) { + return &pos->second; + } + return nullptr; + } + + ~llama_lora_adapter() { + for (struct lm_ggml_context * ctx : ctxs) { + lm_ggml_free(ctx); + } + for (lm_ggml_backend_buffer_t buf : bufs) { + lm_ggml_backend_buffer_free(buf); + } + auto pos = base_model->lora_adapters.find(this); + if (pos != base_model->lora_adapters.end()) { + base_model->lora_adapters.erase(pos); + } + } }; +static size_t llama_get_device_count(const llama_model & model) { + size_t count = 1; +#if defined(LM_GGML_USE_CUDA) + count = lm_ggml_backend_cuda_get_device_count(); +#elif defined(LM_GGML_USE_SYCL) + count = lm_ggml_backend_sycl_get_device_count(); +#elif defined(LM_GGML_USE_VULKAN) + count = lm_ggml_backend_vk_get_device_count(); +#elif defined(LM_GGML_USE_CANN) + return lm_ggml_backend_cann_get_device_count(); +#endif +#if defined(LM_GGML_USE_RPC) + count += model.rpc_servers.size(); +#endif + return count; + LM_GGML_UNUSED(model); +} + static lm_ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) { lm_ggml_backend_buffer_type_t buft = nullptr; -#ifdef LM_GGML_USE_RPC - std::string endpoint = model.rpc_servers[gpu]; - buft = lm_ggml_backend_rpc_buffer_type(endpoint.c_str()); -#elif defined(LM_GGML_USE_METAL) +#if defined(LM_GGML_USE_RPC) + int dev_count = (int)llama_get_device_count(model); + int rpc_count = (int)model.rpc_servers.size(); + if (gpu >= dev_count - rpc_count) { + const char * endpoint = model.rpc_servers[gpu - dev_count + rpc_count].c_str(); + return lm_ggml_backend_rpc_buffer_type(endpoint); + } +#endif +#if defined(LM_GGML_USE_METAL) buft = lm_ggml_backend_metal_buffer_type(); #elif defined(LM_GGML_USE_CUDA) buft = lm_ggml_backend_cuda_buffer_type(gpu); @@ -2394,13 +2933,13 @@ static lm_ggml_backend_buffer_type_t llama_default_buffer_type_offload(const lla buft = lm_ggml_backend_vk_buffer_type(gpu); #elif defined(LM_GGML_USE_SYCL) buft = lm_ggml_backend_sycl_buffer_type(gpu); -#elif defined(LM_GGML_USE_CLBLAST) - buft = lm_ggml_backend_opencl_buffer_type(); #elif defined(LM_GGML_USE_KOMPUTE) buft = lm_ggml_backend_kompute_buffer_type(gpu); if (buft == nullptr) { LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); } +#elif defined(LM_GGML_USE_CANN) + buft = lm_ggml_backend_cann_buffer_type(gpu); #endif if (buft == nullptr) { @@ -2434,29 +2973,19 @@ static lm_ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama LM_GGML_UNUSED(tensor_split); } -static size_t llama_get_device_count(const llama_model & model) { -#if defined(LM_GGML_USE_RPC) - return model.rpc_servers.size(); -#elif defined(LM_GGML_USE_CUDA) - return lm_ggml_backend_cuda_get_device_count(); -#elif defined(LM_GGML_USE_SYCL) - return lm_ggml_backend_sycl_get_device_count(); -#elif defined(LM_GGML_USE_VULKAN) - return lm_ggml_backend_vk_get_device_count(); -#else - return 1; -#endif - LM_GGML_UNUSED(model); -} - static size_t llama_get_device_memory(const llama_model & model, int device) { #if defined(LM_GGML_USE_RPC) - size_t total; - size_t free; - std::string endpoint = model.rpc_servers[device]; - lm_ggml_backend_rpc_get_device_memory(endpoint.c_str(), &free, &total); - return free; -#elif defined(LM_GGML_USE_CUDA) + int dev_count = (int)llama_get_device_count(model); + int rpc_count = (int)model.rpc_servers.size(); + if (device >= dev_count - rpc_count) { + size_t total; + size_t free; + const char * endpoint = model.rpc_servers[device - dev_count + rpc_count].c_str(); + lm_ggml_backend_rpc_get_device_memory(endpoint, &free, &total); + return free; + } +#endif +#if defined(LM_GGML_USE_CUDA) size_t total; size_t free; lm_ggml_backend_cuda_get_device_memory(device, &free, &total); @@ -2471,6 +3000,11 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { size_t free; lm_ggml_backend_vk_get_device_memory(device, &free, &total); return free; +#elif defined(LM_GGML_USE_CANN) + size_t total; + size_t free; + lm_ggml_backend_cann_get_device_memory(device, &total, &free); + return free; #else return 1; #endif @@ -2494,9 +3028,7 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); - const int64_t n_layer = hparams.n_layer; + const int64_t n_layer = hparams.n_layer; cache.has_shift = false; @@ -2504,13 +3036,6 @@ static bool llama_kv_cache_init( cache.recurrent = model.arch == LLM_ARCH_MAMBA; cache.v_trans = !cparams.flash_attn; - // TODO: support mixed recurrent Transformer architectures - // NOTE: (!a || b) is a logical implication (a -> b) - LM_GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); - LM_GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); - LM_GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); - LM_GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); - cache.head = 0; cache.size = kv_size; cache.used = 0; @@ -2528,10 +3053,6 @@ static bool llama_kv_cache_init( } } -#ifdef LM_GGML_USE_CLBLAST - offload = false; -#endif - // count used buffer types std::map buft_layer_count; if (offload) { @@ -2564,6 +3085,9 @@ static bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < (int) n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + struct lm_ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); lm_ggml_tensor * k = lm_ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); lm_ggml_tensor * v = lm_ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); @@ -2844,6 +3368,8 @@ static void llama_kv_cache_seq_add( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; if (cache.recurrent) { // for Mamba-like models, only the pos needs to be shifted @@ -2888,6 +3414,8 @@ static void llama_kv_cache_seq_div( int d) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; if (cache.recurrent) { // for Mamba-like models, only the pos needs to be changed @@ -3344,6 +3872,9 @@ struct llama_model_loader { case LM_GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case LM_GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case LM_GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case LM_GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; + case LM_GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; + case LM_GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, lm_ggml_type_name(type_max)); @@ -3439,9 +3970,9 @@ struct llama_model_loader { bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = lm_gguf_find_key(meta, key.c_str()); - if (kid < 0) { + if (kid < 0 || lm_gguf_get_kv_type(meta, kid) != LM_GGUF_TYPE_ARRAY) { if (required) { - throw std::runtime_error(format("key not found in model: %s", key.c_str())); + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } return false; } @@ -3449,22 +3980,55 @@ struct llama_model_loader { struct GGUFMeta::ArrayInfo arr_info = GGUFMeta::GKV::get_kv(meta, kid); - if (arr_info.gt != LM_GGUF_TYPE_FLOAT32 && arr_info.gt != LM_GGUF_TYPE_INT32) { - throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + switch (arr_info.gt) { + case LM_GGUF_TYPE_FLOAT32: LM_GGML_ASSERT((std::is_same::value)); break; + case LM_GGUF_TYPE_INT32: LM_GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); } - // LM_GGML_ASSERT(lm_gguf_type_size(arr_info.gt) == sizeof(T)); - LM_GGML_ASSERT((arr_info.gt != LM_GGUF_TYPE_FLOAT32 || std::is_same::value)); - LM_GGML_ASSERT((arr_info.gt != LM_GGUF_TYPE_INT32 || std::is_same::value)); - result.resize(arr_info.length); result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } + template + bool get_arr(const std::string & key, std::array & result, const bool required = true) { + const int kid = lm_gguf_find_key(meta, key.c_str()); + + if (kid < 0 || lm_gguf_get_kv_type(meta, kid) != LM_GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + switch (arr_info.gt) { + case LM_GGUF_TYPE_FLOAT32: LM_GGML_ASSERT((std::is_same::value)); break; + case LM_GGUF_TYPE_INT32: LM_GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + if (arr_info.length > N_MAX) { + throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); + } + + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + + return true; + } + template - bool get_arr(const enum llm_kv kid, T& result, const bool required = true) { + bool get_arr(const enum llm_kv kid, T & result, const bool required = true) { return get_arr(llm_kv(kid), result, required); } @@ -3489,6 +4053,52 @@ struct llama_model_loader { return get_key(llm_kv(kid), result, required); } + // get array of n <= N_MAX elements, or a single element repeated n times + template + bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, const bool required = true) { + const int kid = lm_gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + if (n > N_MAX) { + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + } + + if (lm_gguf_get_kv_type(meta, kid) == LM_GGUF_TYPE_ARRAY) { + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (n != arr_info.length) { + throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); + } + + return get_arr(key, result, required); + } else { + T value; + + bool ok = get_key(key, value, required); + if (!ok) { + return false; + } + + for (uint32_t i = 0; i < n; i++) { + result[i] = value; + } + + return true; + } + } + + template + bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) { + return get_key_or_arr(llm_kv(kid), result, n, required); + } + std::string get_arch_name() const { return arch_name; } @@ -3718,6 +4328,44 @@ struct llama_model_loader { std::vector> read_buf; std::vector>> validation_result; +#if defined(LM_GGML_USE_CUDA) + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t n_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector host_ptrs; + std::vector events; + size_t buffer_idx = 0; // buffer to use for async loads + + lm_ggml_backend_t cuda_backend = nullptr; + if (!use_mmap && !check_tensors) { + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the CUDA backend is active, and if so, determine the device ID. + lm_ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr; + if (buf) { + lm_ggml_backend_buffer_type_t buffer_type = lm_ggml_backend_buffer_get_type(buf); + for (int i = 0; i < lm_ggml_backend_cuda_get_device_count(); ++i) { + auto * cuda_buffer_type = lm_ggml_backend_cuda_buffer_type(i); + if (buffer_type == cuda_buffer_type) { + cuda_backend = lm_ggml_backend_cuda_init(i); + break; + } + } + } + + // If the cuda backend is active create pinned memory buffers and events for synchronisation. + if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers; ++idx) { + host_buffers.emplace_back(lm_ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size)); + host_ptrs.emplace_back(lm_ggml_backend_buffer_get_base(host_buffers[idx])); + events.emplace_back(lm_ggml_backend_event_new(cuda_backend)); + } + } + } +#endif + for (struct lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur != NULL; cur = lm_ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(lm_ggml_get_name(cur)); if (weight == nullptr) { @@ -3773,12 +4421,36 @@ struct llama_model_loader { })); } } else { - read_buf.resize(n_size); - file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), n_size); - lm_ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); - if (check_tensors && !lm_ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { - throw std::runtime_error(format("tensor '%s' has invalid data", lm_ggml_get_name(cur))); +#if defined(LM_GGML_USE_CUDA) + // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (cuda_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + lm_ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + lm_ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + lm_ggml_backend_event_record(events[buffer_idx]); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } + else +#endif + { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + lm_ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !lm_ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", lm_ggml_get_name(cur))); + } } } } @@ -3786,6 +4458,18 @@ struct llama_model_loader { size_done += n_size; } +#if defined(LM_GGML_USE_CUDA) + // free temporary resources used for async cuda uploads + if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers;++idx) { + lm_ggml_backend_event_synchronize(events[idx]); + lm_ggml_backend_event_free(events[idx]); + lm_ggml_backend_buffer_free(host_buffers[idx]); + } + lm_ggml_backend_free(cuda_backend); + } +#endif + // check validation results bool validation_failed = false; for (auto & future : validation_result) { @@ -3854,40 +4538,39 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { } switch (ftype) { - case LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "F16"; - case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: - return "Q4_1, some F16"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; - - // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_M :return "IQ1_M - 1.75 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8"; + case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8"; default: return "unknown, may not work"; } @@ -3899,22 +4582,34 @@ static const char * llama_model_type_name(e_model type) { case MODEL_17M: return "17M"; case MODEL_22M: return "22M"; case MODEL_33M: return "33M"; + case MODEL_60M: return "60M"; case MODEL_70M: return "70M"; + case MODEL_80M: return "80M"; case MODEL_109M: return "109M"; case MODEL_137M: return "137M"; case MODEL_160M: return "160M"; + case MODEL_220M: return "220M"; + case MODEL_250M: return "250M"; + case MODEL_270M: return "270M"; case MODEL_335M: return "335M"; case MODEL_410M: return "410M"; + case MODEL_450M: return "450M"; + case MODEL_770M: return "770M"; + case MODEL_780M: return "780M"; case MODEL_0_5B: return "0.5B"; case MODEL_1B: return "1B"; + case MODEL_1_3B: return "1.3B"; case MODEL_1_4B: return "1.4B"; case MODEL_2B: return "2B"; case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; case MODEL_4B: return "4B"; + case MODEL_6B: return "6B"; case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_9B: return "9B"; + case MODEL_11B: return "11B"; case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; @@ -3938,6 +4633,8 @@ static const char * llama_model_type_name(e_model type) { case MODEL_8x22B: return "8x22B"; case MODEL_16x12B: return "16x12B"; case MODEL_10B_128x3_66B: return "10B+128x3.66B"; + case MODEL_57B_A14B: return "57B.A14B"; + case MODEL_27B: return "27B"; default: return "?B"; } } @@ -3948,6 +4645,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ case LLAMA_VOCAB_TYPE_SPM: return "SPM"; case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; default: return "unknown"; } } @@ -3980,20 +4678,18 @@ static void llm_load_hparams( ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); // get hparams kv - ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); // everything past this point is not vocab-related if (hparams.vocab_only) { return; } - ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); - ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); LM_GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); LM_GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); @@ -4003,16 +4699,25 @@ static void llm_load_hparams( LM_GGML_ASSERT(hparams.n_expert_used == 0); } + // zero-out the per-layer hparams + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); + // n_head_kv is optional, default to n_head - hparams.n_head_kv = hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + hparams.n_head_kv_arr = hparams.n_head_arr; + + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - hparams.n_yarn_orig_ctx = hparams.n_ctx_train; - ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false); + hparams.n_ctx_orig_yarn = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false); // rope_freq_base (optional) hparams.rope_freq_base_train = 10000.0f; @@ -4033,27 +4738,33 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); - // sanity check for n_rot (optional) - { - hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; + // non-transformer models do not have attention heads + if (hparams.n_head() > 0) { + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + + hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + + hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + + // sanity check for n_rot (optional) + hparams.n_rot = hparams.n_embd_head_k; ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { - if (hparams.n_rot != hparams.n_embd / hparams.n_head) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } } - // gpt-neox n_rot = rotary_pct * (n_embd / n_head) - // gpt-j n_rot = rotary_dim + } else { + hparams.n_rot = 0; + hparams.n_embd_head_k = 0; + hparams.n_embd_head_v = 0; } - hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - - hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); - // arch-specific KVs switch (model.arch) { case LLM_ARCH_LLAMA: @@ -4076,7 +4787,7 @@ static void llm_load_hparams( case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; + case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break; default: model.type = e_model::MODEL_UNKNOWN; } } @@ -4245,16 +4956,20 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = hparams.n_head == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; + case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; case 80: model.type = e_model::MODEL_70B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; case LLM_ARCH_QWEN2MOE: { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_A2_7B; break; + case 28: model.type = e_model::MODEL_57B_A14B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4335,6 +5050,21 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GEMMA2: + { + hparams.n_swa = 4096; // default value of gemma 2 + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + hparams.attn_soft_cap = true; + + switch (hparams.n_layer) { + case 42: model.type = e_model::MODEL_9B; break; + case 46: model.type = e_model::MODEL_27B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -4418,46 +5148,58 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OPENELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_270M; break; + case 20: model.type = e_model::MODEL_450M; break; + case 28: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_GPTNEOX: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); switch (hparams.n_layer) { case 6: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 512: model.type = e_model::MODEL_14M; break; case 2048: model.type = e_model::MODEL_70M; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 12: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 3072: model.type = e_model::MODEL_160M; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 16: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 8192: model.type = e_model::MODEL_1B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 24: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 4096: model.type = e_model::MODEL_410M; break; case 8192: model.type = e_model::MODEL_1_4B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 32: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 10240: model.type = e_model::MODEL_2_8B; break; case 16384: model.type = e_model::MODEL_6_9B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 36: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 20480: model.type = e_model::MODEL_12B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 44: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 24576: model.type = e_model::MODEL_20B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; @@ -4497,6 +5239,68 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_6B; break; + case 40: model.type = e_model::MODEL_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_60M; break; // t5-small + case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: model.type = e_model::MODEL_220M; break; // t5-base + case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: model.type = e_model::MODEL_770M; break; // t5-large + case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large + case 16384: model.type = e_model::MODEL_3B; break; // t5-3b + case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl + case 65536: model.type = e_model::MODEL_11B; break; // t5-11b + case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_JAIS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1_3B; break; + case 40: model.type = e_model::MODEL_13B; break; + /* TODO: add variants */ + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4557,40 +5361,6 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; - - // For Fill-In-the-Middle (FIM)/infill models which where converted - // prior to support of FIM special tokens in GGUF, the following - // will allow those models to continue to work. The general names - // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and - // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once - // new versions of these models have been published. - std::string gen_name; - ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); - - std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), - [](unsigned char c){ return std::tolower(c); }); - - if (gen_name.find("code") != std::string::npos) { - if (model.arch == LLM_ARCH_LLAMA) { - vocab.special_prefix_id = 32007; - vocab.special_suffix_id = 32008; - vocab.special_middle_id = 32009; - vocab.special_eot_id = 32010; - } else if (model.arch == LLM_ARCH_GEMMA) { - vocab.special_prefix_id = 67; - vocab.special_suffix_id = 69; - vocab.special_middle_id = 68; - // TODO: this is not EOT, it is "file separator" token, needs fix - // https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572 - //vocab.special_eot_id = 70; - vocab.special_eot_id = 107; - } - } - - const int add_space_prefix_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); - if (add_space_prefix_keyidx != -1) { - vocab.add_space_prefix = lm_gguf_get_val_bool(ctx, add_space_prefix_keyidx); - } // The default value of add_space_prefix is true. } else if (tokenizer_model == "bert") { vocab.type = LLAMA_VOCAB_TYPE_WPM; @@ -4602,29 +5372,15 @@ static void llm_load_vocab( vocab.special_pad_id = 0; vocab.special_cls_id = 101; vocab.special_mask_id = 103; - vocab.add_space_prefix = false; - } else { - if (tokenizer_model == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else if (tokenizer_model == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; - const int add_space_prefix_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); - if (add_space_prefix_keyidx != -1) { - vocab.add_space_prefix = lm_gguf_get_val_bool(ctx, add_space_prefix_keyidx); - } - } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - vocab.type = LLAMA_VOCAB_TYPE_SPM; - return; - } // read bpe merges and populate bpe ranks const int merges_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { throw std::runtime_error("cannot find tokenizer merges in model file\n"); } - const int n_merges = lm_gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { const std::string word = lm_gguf_get_arr_str(ctx, merges_keyidx, i); LM_GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); @@ -4650,10 +5406,53 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; + } else if (tokenizer_model == "t5") { + vocab.type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = 1; + vocab.special_unk_id = 2; + vocab.special_sep_id = -1; + vocab.special_pad_id = 0; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + + const int add_space_prefix_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); + if (add_space_prefix_keyidx != -1) { + vocab.tokenizer_add_space_prefix = lm_gguf_get_val_bool(ctx, add_space_prefix_keyidx); + } // The default value of add_space_prefix is true. + + const int remove_extra_whitespaces_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS).c_str()); + if (remove_extra_whitespaces_keyidx != -1) { + vocab.tokenizer_remove_extra_whitespaces = lm_gguf_get_val_bool(ctx, remove_extra_whitespaces_keyidx); + } // The default value of remove_extra_whitespaces is false. + + const int precompiled_charsmap_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + size_t n_precompiled_charsmap = lm_gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * precompiled_charsmap = (const char *) lm_gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); + } +#endif + } + } else { + throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } // for now, only BPE models have pre-tokenizers if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = true; if (tokenizer_pre.empty()) { LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); @@ -4663,20 +5462,23 @@ static void llm_load_vocab( LLAMA_LOG_WARN("%s: ************************************ \n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if ( - tokenizer_pre == "default") { + } else if (tokenizer_pre == "default") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + vocab.tokenizer_ignore_merges = true; + vocab.tokenizer_add_bos = true; } else if ( tokenizer_pre == "deepseek-llm") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "deepseek-coder") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "falcon") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -4688,10 +5490,12 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; } else if ( tokenizer_pre == "gpt-2" || + tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de") { + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "jina-v2-code") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "refact") { @@ -4699,9 +5503,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "command-r") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "qwen2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; @@ -4714,12 +5520,48 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "smaug-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "poro-chat") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "chatglm-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + vocab.special_bos_id = -1; + } else if ( + tokenizer_pre == "viking") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "jais") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } + } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = true; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = true; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = true; } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } + + const int add_space_prefix_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); + if (add_space_prefix_keyidx != -1) { + vocab.tokenizer_add_space_prefix = lm_gguf_get_val_bool(ctx, add_space_prefix_keyidx); + } } const int token_idx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); @@ -4748,16 +5590,68 @@ static void llm_load_vocab( LM_GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); vocab.token_to_id[word] = i; + vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); token_data.score = scores ? scores[i] : 0.0f; - token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file + switch(toktypes[i]) { + case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; + case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; + case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; + case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; + case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; + case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; + case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + } + } } LM_GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + // For Fill-In-the-Middle (FIM)/infill models which where converted + // prior to support of FIM special tokens in GGUF, the following + // will allow those models to continue to work. The general names + // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and + // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once + // new versions of these models have been published. + std::string gen_name; + ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); + + std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), + [](unsigned char c){ return std::tolower(c); }); + + if (gen_name.find("code") != std::string::npos) { + if (model.arch == LLM_ARCH_LLAMA + && 32010 < vocab.id_to_token.size() + && vocab.id_to_token[32007].text.find("
") != std::string::npos
+              && vocab.id_to_token[32008].text.find("") != std::string::npos
+              && vocab.id_to_token[32009].text.find("") != std::string::npos
+              && vocab.id_to_token[32010].text.find("") != std::string::npos) {
+                vocab.special_prefix_id = 32007;
+                vocab.special_suffix_id = 32008;
+                vocab.special_middle_id = 32009;
+                vocab.special_eot_id    = 32010;
+            } else if (model.arch == LLM_ARCH_GEMMA
+              && 107 < vocab.id_to_token.size()
+              && vocab.id_to_token[67].text == "<|fim_prefix|>"
+              && vocab.id_to_token[69].text == "<|fim_suffix|>"
+              && vocab.id_to_token[68].text == "<|fim_middle|>"
+              && vocab.id_to_token[107].text == "") {
+                vocab.special_prefix_id = 67;
+                vocab.special_suffix_id = 69;
+                vocab.special_middle_id = 68;
+                // TODO: this is not EOT, it is "file separator" token, needs fix
+                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
+                //vocab.special_eot_id    = 70;
+                vocab.special_eot_id    = 107;
+            }
+        }
         try {
             vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
         } catch (const std::exception & e) {
@@ -4809,10 +5703,10 @@ static void llm_load_vocab(
             bool temp = true;
 
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.special_add_bos = int(temp);
+                vocab.tokenizer_add_bos = temp;
             }
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.special_add_eos = int(temp);
+                vocab.tokenizer_add_eos = temp;
             }
         }
 
@@ -4843,18 +5737,88 @@ static void llm_load_vocab(
     // build special tokens cache
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                vocab.special_tokens_cache.push_back(id);
+            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                vocab.cache_special_tokens.push_back(id);
             }
         }
 
-        std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
             [&] (const llama_vocab::id a, const llama_vocab::id b) {
                 return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
         );
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector cache_token_to_piece(n_vocab);
+
+        for (uint32_t id = 0; id < n_vocab; ++id) {
+            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
+
+            size_cache += cache_token_to_piece[id].size();
+        }
+
+        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string &str, const std::vector &substrs) -> bool {
+            for (auto substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
+            uint32_t current = vocab.id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            vocab.id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : vocab.cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
     }
 }
 
@@ -4864,43 +5828,78 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 
     const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
 
+    auto print_f = [](const std::function & f, uint32_t n) {
+        bool is_var = false;
+
+        std::vector v;
+        for (uint32_t i = 0; i < n; ++i) {
+            v.push_back(f(i));
+            if (v[i] != v[0]) {
+                is_var = true;
+            }
+        }
+
+        std::stringstream ss;
+
+        if (is_var) {
+            ss << "[";
+            for (uint32_t i = 0; i < n; ++i) {
+                ss << v[i];
+                if (i < n - 1) {
+                    ss << ", ";
+                }
+            }
+            ss << "]";
+        } else {
+            ss << v[0];
+        }
+
+        return ss.str();
+    };
+
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
     LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
     LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
-    LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
-    LLAMA_LOG_INFO("%s: n_head           = %u\n",     __func__, hparams.n_head);
-    LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
-    LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
-    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
-    LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
-    LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
-    LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
-    LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %u\n",     __func__, hparams.n_embd_k_gqa());
-    LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %u\n",     __func__, hparams.n_embd_v_gqa());
-    LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
-    LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-    LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
-    LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
-    LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
-    LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
-    LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
-    LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
-    LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
-    LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
-    LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
-    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
-    LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
-    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
-    LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx);
-    LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
-    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
-    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
-    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
-    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
+
+    if (!hparams.vocab_only) {
+        LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
+        LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
+        LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
+        LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+        LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
+        LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
+        LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
+        LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
+        LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+        LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
+        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
+        LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+        LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
+        LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
+        LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
+        LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
+        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
+        LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
+        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
+        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+    }
+
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     if (ml.n_elements >= 1e12) {
@@ -4936,6 +5935,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
 
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
@@ -4945,6 +5946,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
+
+    if (model.arch == LLM_ARCH_QWEN2MOE) {
+        LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
+    }
 }
 
 // Returns false if cancelled by progress_callback
@@ -4962,19 +5968,12 @@ static bool llm_load_tensors(
 
     auto & hparams = model.hparams;
 
-#ifdef LM_GGML_USE_SYCL
-    // disable MoE with SYCL until mul_mat_id is updated
-    if (hparams.n_expert > 0) {
-        n_gpu_layers = 0;
-    }
-#endif
-
     model.split_mode   = split_mode;
     model.main_gpu     = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
 
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
+    const int n_layer     = hparams.n_layer;
+    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
     bool use_mmap_buffer = true;
 
     // there is very little benefit to offloading the input layer, so always keep it on the CPU
@@ -4984,7 +5983,7 @@ static bool llm_load_tensors(
     model.buft_layer.resize(n_layer);
 
     // assign cpu layers
-    for (int64_t i = 0; i < i_gpu_start; ++i) {
+    for (int i = 0; i < i_gpu_start; ++i) {
         model.buft_layer[i] = llama_default_buffer_type_cpu(true);
     }
 
@@ -5014,7 +6013,7 @@ static bool llm_load_tensors(
 
         // assign the repeating layers to the devices according to the splits
         int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
-        for (int64_t i = i_gpu_start; i < n_layer; ++i) {
+        for (int i = i_gpu_start; i < n_layer; ++i) {
             int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
             model.buft_layer[i] = llama_default_buffer_type_offload(model, layer_gpu);
         }
@@ -5034,7 +6033,7 @@ static bool llm_load_tensors(
             split_buft = llama_default_buffer_type_offload(model, main_gpu);
         }
         // assign the repeating layers
-        for (int64_t i = i_gpu_start; i < n_layer; ++i) {
+        for (int i = i_gpu_start; i < n_layer; ++i) {
             model.buft_layer[i] = {
                 split_buft,
                 llama_default_buffer_type_offload(model, main_gpu)
@@ -5057,7 +6056,7 @@ static bool llm_load_tensors(
     buft_layer_count[model.buft_input.buft_matrix]++;
     buft_layer_count[model.buft_output.buft]++;
     buft_layer_count[model.buft_output.buft_matrix]++;
-    for (int64_t i = 0; i < n_layer; ++i) {
+    for (int i = 0; i < n_layer; ++i) {
         buft_layer_count[model.buft_layer[i].buft]++;
         buft_layer_count[model.buft_layer[i].buft_matrix]++;
     }
@@ -5087,15 +6086,21 @@ static bool llm_load_tensors(
 
     // create tensors for the weights
     {
-        const int64_t n_embd       = hparams.n_embd;
-        const int64_t n_embd_head  = n_embd / hparams.n_head;
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const int64_t n_embd_gqa   = n_embd_v_gqa;
-        const int64_t n_vocab      = hparams.n_vocab;
-        const int64_t n_vocab_type = hparams.n_vocab_type;
-        const int64_t n_ff         = hparams.n_ff;
-        const int64_t n_expert     = hparams.n_expert;
+        // note: cast to int64_t since we will use these for the tensor dimensions
+        const int64_t n_head        = hparams.n_head();
+        const int64_t n_head_kv     = hparams.n_head_kv();
+        const int64_t n_embd        = hparams.n_embd;
+        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+        const int64_t n_embd_head_v = hparams.n_embd_head_v;
+        const int64_t n_ff          = hparams.n_ff();
+        const int64_t n_embd_gqa    = n_embd_v_gqa;
+        const int64_t n_vocab       = hparams.n_vocab;
+        const int64_t n_vocab_type  = hparams.n_vocab_type;
+        const int64_t n_expert      = hparams.n_expert;
+        const int64_t n_expert_used = hparams.n_expert_used;
+        const int64_t n_ctx_train   = hparams.n_ctx_train;
 
         if (n_expert > 0 && hparams.n_expert_used == 0) {
             throw std::runtime_error("model has expert layers but no expert layers are used");
@@ -5104,8 +6109,9 @@ static bool llm_load_tensors(
         lm_ggml_context * ctx_input        = ctx_map.at(model.buft_input.buft);
         lm_ggml_context * ctx_output       = ctx_map.at(model.buft_output.buft);
         lm_ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
-        auto ctx_for_layer              = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
-        auto ctx_for_layer_split        = [&](int i) { return ctx_map.at(model.buft_layer[i].buft_matrix); };
+
+        auto ctx_for_layer       = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
+        auto ctx_for_layer_split = [&](int i) { return ctx_map.at(model.buft_layer[i].buft_matrix); };
 
         model.layers.resize(n_layer);
 
@@ -5120,12 +6126,11 @@ static bool llm_load_tensors(
                     // output
                     {
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        if (model.arch != LLM_ARCH_MINICPM){
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            // if output is NULL, init from the input tok embed
-                            if (model.output == NULL) {
-                                model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                            }
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                         }
                     }
 
@@ -5205,6 +6210,7 @@ static bool llm_load_tensors(
                     {
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                         model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         // if output is NULL, init from the input tok embed
                         if (model.output == NULL) {
                             model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
@@ -5228,9 +6234,9 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-
+                        layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                         layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         if (layer.ffn_gate_exps) {
                             layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
                             layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
@@ -5282,12 +6288,12 @@ static bool llm_load_tensors(
 
                     auto & layer = model.layers[i];
 
-                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,  "weight", i), {n_embd});
+                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
                     layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                     layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
-                    layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+                    layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
 
                     layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                     layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert});
@@ -5329,10 +6335,10 @@ static bool llm_load_tensors(
 
                     // output
                     {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
 
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         if (!model.output) {
                             model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
                         }
@@ -5360,7 +6366,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_STARCODER:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, hparams.n_ctx_train});
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
 
                     // output
                     {
@@ -5386,8 +6392,8 @@ static bool llm_load_tensors(
                         layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                         layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
 
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
@@ -5395,8 +6401,8 @@ static bool llm_load_tensors(
                         layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                         layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
+                        layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
                     }
                 } break;
             case LLM_ARCH_BERT:
@@ -5404,8 +6410,9 @@ static bool llm_load_tensors(
                 {
                     model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
                     model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
+
                     if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, hparams.n_ctx_train});
+                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train});
                     }
 
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@@ -5418,33 +6425,32 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];
 
                         if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                            layer.bq   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
+                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                            layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
 
-                            layer.wk   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bk   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
+                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                            layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
 
-                            layer.wv   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bv   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                            layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
                         } else {
                             layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                         }
 
-                        layer.wo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
 
                         layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
                         layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
 
-                        layer.ffn_up          = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
-                        layer.ffn_down        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
 
                         if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
-
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
                         } else {
-                            layer.ffn_gate   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                         }
 
                         layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
@@ -5453,8 +6459,9 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_JINA_BERT_V2:
                 {
-                    model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
-                    model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); //token_type_embeddings
+                    model.tok_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
+                    model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings
+
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
                     model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias
 
@@ -5464,35 +6471,38 @@ static bool llm_load_tensors(
 
                         auto & layer = model.layers[i]; // JinaBertLayer
 
-                        layer.wq   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.bq   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
 
                         layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wk   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.bk   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa});
 
                         layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wv   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.bv   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa});
 
-                        layer.wo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}); //output_dens
-                        layer.bo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "bias", i), {n_embd}); //output_dens
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens
+                        layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}); //output_dens
 
                         layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
+                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});
+
+                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,    "weight", i), {n_embd, n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
 
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,        "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN,      "bias", i), {n_embd});
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
 
-                        layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM,        "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM,        "bias", i), {n_embd});
+                        layer.layer_out_norm   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd});
                     }
                 } break;
             case LLM_ARCH_BLOOM:
@@ -5515,35 +6525,35 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];
 
                         layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd});
 
                         layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa});
 
                         layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd});
 
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd});
 
                         layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff});
                     }
                 } break;
             case LLM_ARCH_MPT:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, hparams.n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                     // output
                     {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         if (!model.output) {
                             model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
                         }
@@ -5614,8 +6624,8 @@ static bool llm_load_tensors(
                         layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -5726,20 +6736,23 @@ static bool llm_load_tensors(
 
                         layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
 
-                        LM_GGML_ASSERT(hparams.n_expert      > 0);
-                        LM_GGML_ASSERT(hparams.n_expert_used > 0);
+                        LM_GGML_ASSERT(n_expert      > 0);
+                        LM_GGML_ASSERT(n_expert_used > 0);
 
                         // MoE branch
-                        auto n_ff_exp = n_ff / hparams.n_expert_used;
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
                         layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
                         layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
                         layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
 
                         // Shared expert branch
+                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
+
                         layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
-                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp});
+                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd});
+                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp});
                     }
                 } break;
             case LLM_ARCH_PHI2:
@@ -5789,6 +6802,8 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_PHI3:
                 {
+                    const int64_t n_embd_head = n_embd / n_head;
+
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
 
                     // output
@@ -5798,8 +6813,8 @@ static bool llm_load_tensors(
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        lm_ggml_context* ctx_layer = ctx_for_layer(i);
-                        lm_ggml_context* ctx_split = ctx_for_layer_split(i);
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
 
                         auto & layer = model.layers[i];
 
@@ -5848,7 +6863,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_GPT2:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"),   {n_embd, hparams.n_ctx_train});
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
 
                     // output
                     {
@@ -5985,12 +7000,34 @@ static bool llm_load_tensors(
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
-                    const int64_t n_ff          = hparams.n_ff;
-                    const int64_t n_embd_head_k = hparams.n_embd_head_k;
-                    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-                    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+                    for (int i = 0; i < n_layer; ++i) {
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                    }
+                } break;
+            case LLM_ARCH_GEMMA2:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
-                    for (uint32_t i = 0; i < n_layer; ++i) {
+                    for (int i = 0; i < n_layer; ++i) {
                         lm_ggml_context * ctx_layer = ctx_for_layer(i);
                         lm_ggml_context * ctx_split = ctx_for_layer_split(i);
 
@@ -5998,15 +7035,17 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                        layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
                     }
                 } break;
             case LLM_ARCH_STARCODER2:
@@ -6063,6 +7102,7 @@ static bool llm_load_tensors(
                     const int64_t d_inner = hparams.ssm_d_inner;
                     const int64_t d_state = hparams.ssm_d_state;
                     const int64_t dt_rank = hparams.ssm_dt_rank;
+
                     // only an expansion factor of 2 is supported for now
                     LM_GGML_ASSERT(2 * n_embd == d_inner);
 
@@ -6113,15 +7153,20 @@ static bool llm_load_tensors(
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                         model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
                     }
+
                     for (int i = 0; i < n_layer; ++i) {
                         lm_ggml_context * ctx_layer = ctx_for_layer(i);
                         lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
                         auto & layer = model.layers[i];
+
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
@@ -6148,8 +7193,8 @@ static bool llm_load_tensors(
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
                         if (n_layer >= 64){
-                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head});
-                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv});
+                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head});
+                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv});
                         }
 
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
@@ -6185,15 +7230,49 @@ static bool llm_load_tensors(
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
-
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        // init output from the input tok embed
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head      =   hparams.n_head(i);
+                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
+                        const int64_t n_ff        =   hparams.n_ff(i);
+
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k});
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                    }
+                } break;
             case LLM_ARCH_GPTNEOX:
                 {
-                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
                     // output
                     {
                         model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
@@ -6232,8 +7311,9 @@ static bool llm_load_tensors(
 
                     // output
                     {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         // if output is NULL, init from the input tok embed
                         if (model.output == NULL) {
                             model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
@@ -6268,13 +7348,16 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_DEEPSEEK2:
                 {
-                    bool is_lite = (hparams.n_layer == 27);
+                    const bool is_lite = (hparams.n_layer == 27);
+
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
 
-                    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-                    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-                    const uint32_t q_lora_rank = hparams.n_lora_q;
-                    const uint32_t kv_lora_rank = hparams.n_lora_kv;
-                    const uint32_t n_ff_exp = hparams.n_ff_exp;
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
 
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
@@ -6294,29 +7377,31 @@ static bool llm_load_tensors(
                         if (!is_lite) {
                             layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
                         }
+
                         layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
 
                         if (!is_lite) {
-                            layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A,   "weight", i), {n_embd, q_lora_rank});
-                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B,   "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k});
+                            layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
+                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
                         } else {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
                         }
-                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA,   "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope});
-                        layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,   "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd});
+
+                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
+                        layer.wkv_b     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+                        layer.wo        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        if ((uint32_t) i < hparams.n_layer_dense_lead) {
+                        if (i < (int) hparams.n_layer_dense_lead) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                             layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                         } else {
                             layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
 
-                            LM_GGML_ASSERT(hparams.n_expert      > 0);
-                            LM_GGML_ASSERT(hparams.n_expert_used > 0);
+                            LM_GGML_ASSERT(n_expert      > 0);
+                            LM_GGML_ASSERT(n_expert_used > 0);
 
                             // MoE branch
                             layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
@@ -6324,11 +7409,178 @@ static bool llm_load_tensors(
                             layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
 
                             // Shared expert branch
-                            layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,   n_ff_exp * hparams.n_expert_shared});
-                            layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff_exp * hparams.n_expert_shared, n_embd});
-                            layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,   n_ff_exp * hparams.n_expert_shared});
+                            layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared});
+                            layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd});
+                            layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared});
+                        }
+                    }
+                } break;
+            case LLM_ARCH_BITNET:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd});
+                        layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
+
+                        layer.wq       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1});
+                        layer.wk       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1});
+                        layer.wv       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1});
+                        layer.wo       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1});
+
+                        layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd});
+                        layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
+
+                        layer.ffn_gate       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1});
+                        layer.ffn_down       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1});
+                        layer.ffn_up         = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_scale   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1});
+                    }
+                } break;
+            case LLM_ARCH_T5:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm     = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd});
+
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                         }
                     }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                        layer.attn_norm  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.attn_norm_cross  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd});
+                        // this tensor seems to be unused in HF transformers implementation
+                        layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
+            case LLM_ARCH_JAIS:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // Output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                        layer.ffn_gate   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                    }
+                } break;
+            case LLM_ARCH_CHATGLM:
+                {
+                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        lm_ggml_context * ctx_layer = ctx_for_layer(i);
+                        lm_ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + (hparams.n_embd_head_k << 2)});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    }
                 } break;
             default:
                 throw std::runtime_error("unknown architecture");
@@ -6529,16 +7781,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         }
 #endif
 
-#ifdef LM_GGML_USE_SYCL
-        if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
-            lm_ggml_backend_sycl_set_single_device_mode(params.main_gpu);
-            //SYCL use device index (0, 1, 2) directly, uer input device id, then convert to device index.
-            params.main_gpu = lm_ggml_backend_sycl_get_device_index(params.main_gpu);
-        } else {
-            lm_ggml_backend_sycl_set_mul_device_mode();
-        }
-#endif
-
         if (!llm_load_tensors(
             ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
             params.progress_callback, params.progress_callback_user_data
@@ -6564,6 +7806,7 @@ enum llm_ffn_op_type {
     LLM_FFN_GELU,
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
+    LLM_FFN_SWIGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -6618,8 +7861,8 @@ static void llm_build_kv_store(
                     int64_t   il) {
     const int64_t n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
     LM_GGML_ASSERT(kv.size == n_ctx);
 
@@ -6650,6 +7893,58 @@ static void llm_build_kv_store(
     lm_ggml_build_forward_expand(graph, lm_ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
+// do mat_mul, while optionally apply lora
+static struct lm_ggml_tensor * llm_build_lora_mm(
+        struct llama_context & lctx,
+         struct lm_ggml_context * ctx0,
+          struct lm_ggml_tensor * w,
+          struct lm_ggml_tensor * cur) {
+    struct lm_ggml_tensor * res = lm_ggml_mul_mat(ctx0, w, cur);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct lm_ggml_tensor * ab_cur = lm_ggml_mul_mat(
+            ctx0, lora->b,
+            lm_ggml_mul_mat(ctx0, lora->a, cur)
+        );
+        ab_cur = lm_ggml_scale(ctx0, ab_cur, scale);
+        res = lm_ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
+// do mat_mul_id, while optionally apply lora
+static struct lm_ggml_tensor * llm_build_lora_mm_id(
+        struct llama_context & lctx,
+         struct lm_ggml_context * ctx0,
+          struct lm_ggml_tensor * w,   // struct lm_ggml_tensor * as
+          struct lm_ggml_tensor * cur, // struct lm_ggml_tensor * b
+          struct lm_ggml_tensor * ids) {
+    struct lm_ggml_tensor * res = lm_ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct lm_ggml_tensor * ab_cur = lm_ggml_mul_mat_id(
+            ctx0, lora->b,
+            lm_ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ids
+        );
+        ab_cur = lm_ggml_scale(ctx0, ab_cur, scale);
+        res = lm_ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
 static struct lm_ggml_tensor * llm_build_norm(
         struct lm_ggml_context * ctx,
          struct lm_ggml_tensor * cur,
@@ -6684,19 +7979,23 @@ static struct lm_ggml_tensor * llm_build_norm(
 
 static struct lm_ggml_tensor * llm_build_ffn(
         struct lm_ggml_context * ctx,
+       struct llama_context & lctx,
          struct lm_ggml_tensor * cur,
          struct lm_ggml_tensor * up,
          struct lm_ggml_tensor * up_b,
+         struct lm_ggml_tensor * up_s,
          struct lm_ggml_tensor * gate,
          struct lm_ggml_tensor * gate_b,
+         struct lm_ggml_tensor * gate_s,
          struct lm_ggml_tensor * down,
          struct lm_ggml_tensor * down_b,
+         struct lm_ggml_tensor * down_s,
          struct lm_ggml_tensor * act_scales,
             llm_ffn_op_type   type_op,
           llm_ffn_gate_type   type_gate,
          const llm_build_cb & cb,
                         int   il) {
-    struct lm_ggml_tensor * tmp = up ? lm_ggml_mul_mat(ctx, up, cur) : cur;
+    struct lm_ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -6704,16 +8003,21 @@ static struct lm_ggml_tensor * llm_build_ffn(
         cb(tmp, "ffn_up_b", il);
     }
 
+    if (up_s) {
+        tmp = lm_ggml_mul(ctx, tmp, up_s);
+        cb(tmp, "ffn_up_s", il);
+    }
+
     if (gate) {
         switch (type_gate) {
             case LLM_FFN_SEQ:
                 {
-                    cur = lm_ggml_mul_mat(ctx, gate, tmp);
+                    cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
                     cb(cur, "ffn_gate", il);
                 } break;
             case LLM_FFN_PAR:
                 {
-                    cur = lm_ggml_mul_mat(ctx, gate, cur);
+                    cur = llm_build_lora_mm(lctx, ctx, gate, cur);
                     cb(cur, "ffn_gate", il);
                 } break;
         }
@@ -6722,6 +8026,12 @@ static struct lm_ggml_tensor * llm_build_ffn(
             cur = lm_ggml_add(ctx, cur, gate_b);
             cb(cur, "ffn_gate_b", il);
         }
+
+        if (gate_s) {
+            cur = lm_ggml_mul(ctx, cur, gate_s);
+            cb(cur, "ffn_gate_s", il);
+        }
+
     } else {
         cur = tmp;
     }
@@ -6754,6 +8064,19 @@ static struct lm_ggml_tensor * llm_build_ffn(
                 cur = lm_ggml_sqr(ctx, cur);
                 cb(cur, "ffn_sqr(relu)", il);
             } break;
+        case LLM_FFN_SWIGLU:
+            {
+                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+                int64_t split_point = cur->ne[0] / 2;
+                struct lm_ggml_tensor * x0 = lm_ggml_cont(ctx, lm_ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                struct lm_ggml_tensor * x1 = lm_ggml_cont(ctx, lm_ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * lm_ggml_element_size(cur)));
+
+                x0 = lm_ggml_silu(ctx, x0);
+                cb(cur, "ffn_silu", il);
+
+                cur = lm_ggml_mul(ctx, x0, x1);
+                cb(cur, "ffn_mul", il);
+            } break;
     }
 
     if (type_gate == LLM_FFN_PAR) {
@@ -6761,7 +8084,10 @@ static struct lm_ggml_tensor * llm_build_ffn(
         cb(cur, "ffn_gate_par", il);
     }
 
-    cur = lm_ggml_mul_mat(ctx, down, cur);
+    if (down) {
+        cur = llm_build_lora_mm(lctx, ctx, down, cur);
+    }
+
     if (down_b) {
         cb(cur, "ffn_down", il);
     }
@@ -6770,11 +8096,17 @@ static struct lm_ggml_tensor * llm_build_ffn(
         cur = lm_ggml_add(ctx, cur, down_b);
     }
 
+    if (down_s) {
+        cur = lm_ggml_mul(ctx, cur, down_s);
+        cb(cur, "ffn_down_s", il);
+    }
+
     return cur;
 }
 
 static struct lm_ggml_tensor * llm_build_moe_ffn(
         struct lm_ggml_context * ctx,
+       struct llama_context & lctx,
          struct lm_ggml_tensor * cur,
          struct lm_ggml_tensor * gate_inp,
          struct lm_ggml_tensor * up_exps,
@@ -6791,7 +8123,7 @@ static struct lm_ggml_tensor * llm_build_moe_ffn(
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
 
-    lm_ggml_tensor * logits = lm_ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens]
+    lm_ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
     lm_ggml_tensor * probs = lm_ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
@@ -6823,10 +8155,10 @@ static struct lm_ggml_tensor * llm_build_moe_ffn(
     }
 
     cur = lm_ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    lm_ggml_tensor * up = lm_ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    lm_ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    lm_ggml_tensor * gate = lm_ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    lm_ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);
 
     switch (type_op) {
@@ -6847,7 +8179,7 @@ static struct lm_ggml_tensor * llm_build_moe_ffn(
     lm_ggml_tensor * par = lm_ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
     cb(par, "ffn_moe_gate_par", il);
 
-    lm_ggml_tensor * experts = lm_ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    lm_ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     experts = lm_ggml_mul(ctx, experts, weights);
@@ -6875,9 +8207,7 @@ static struct lm_ggml_tensor * llm_build_moe_ffn(
 
 static struct lm_ggml_tensor * llm_build_kqv(
         struct lm_ggml_context * ctx,
-          const llama_model & model,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
+       struct llama_context & lctx,
        const llama_kv_cache & kv,
          struct lm_ggml_cgraph * graph,
          struct lm_ggml_tensor * wo,
@@ -6889,13 +8219,17 @@ static struct lm_ggml_tensor * llm_build_kqv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const llama_model   & model   = lctx.model;
+    const llama_hparams & hparams = lctx.model.hparams;
+    const llama_cparams & cparams = lctx.cparams;
+
     const int64_t n_ctx         = cparams.n_ctx;
-    const int64_t n_head        = hparams.n_head;
-    const int64_t n_head_kv     = hparams.n_head_kv;
+    const int64_t n_head        = hparams.n_head(il);
+    const int64_t n_head_kv     = hparams.n_head_kv(il);
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
-    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
-    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
 
     struct lm_ggml_tensor * q = lm_ggml_permute(ctx, q_cur, 0, 2, 1, 3);
     cb(q, "q", il);
@@ -6934,7 +8268,7 @@ static struct lm_ggml_tensor * llm_build_kqv(
         struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32);
@@ -6954,6 +8288,12 @@ static struct lm_ggml_tensor * llm_build_kqv(
             kq = lm_ggml_scale(ctx, kq, 30);
         }
 
+        if (hparams.attn_soft_cap) {
+            kq = lm_ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = lm_ggml_tanh(ctx, kq);
+            kq = lm_ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        }
+
         kq = lm_ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
 
@@ -6980,7 +8320,10 @@ static struct lm_ggml_tensor * llm_build_kqv(
 
     lm_ggml_build_forward_expand(graph, cur);
 
-    cur = lm_ggml_mul_mat(ctx, wo, cur);
+    if (wo) {
+        cur = llm_build_lora_mm(lctx, ctx, wo, cur);
+    }
+
     if (wo_b) {
         cb(cur, "kqv_wo", il);
     }
@@ -6994,9 +8337,7 @@ static struct lm_ggml_tensor * llm_build_kqv(
 
 static struct lm_ggml_tensor * llm_build_kv(
         struct lm_ggml_context * ctx,
-          const llama_model & model,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
+       struct llama_context & lctx,
        const llama_kv_cache & kv,
          struct lm_ggml_cgraph * graph,
          struct lm_ggml_tensor * wo,
@@ -7011,6 +8352,8 @@ static struct lm_ggml_tensor * llm_build_kv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const llama_hparams & hparams = lctx.model.hparams;
+    const llama_cparams & cparams = lctx.cparams;
 
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
@@ -7022,7 +8365,7 @@ static struct lm_ggml_tensor * llm_build_kv(
 
     struct lm_ggml_tensor * cur;
 
-    cur  = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
+    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
             q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
@@ -7062,8 +8405,9 @@ struct llm_build_context {
     const int32_t n_tokens;
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
     const int32_t n_outputs;
+    const int32_t n_outputs_enc;
     const int32_t kv_head;  // index of where we store new KV data in the cache
-    const int32_t n_orig_ctx;
+    const int32_t n_ctx_orig;
 
     const bool flash_attn;
 
@@ -7092,8 +8436,8 @@ struct llm_build_context {
         n_layer          (hparams.n_layer),
         n_rot            (hparams.n_rot),
         n_ctx            (cparams.n_ctx),
-        n_head           (hparams.n_head),
-        n_head_kv        (hparams.n_head_kv),
+        n_head           (hparams.n_head()),
+        n_head_kv        (hparams.n_head_kv()),
         n_embd_head_k    (hparams.n_embd_head_k),
         n_embd_k_gqa     (hparams.n_embd_k_gqa()),
         n_embd_head_v    (hparams.n_embd_head_v),
@@ -7111,8 +8455,9 @@ struct llm_build_context {
         n_tokens         (batch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
+        n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
-        n_orig_ctx       (cparams.n_yarn_orig_ctx),
+        n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
@@ -7130,17 +8475,21 @@ struct llm_build_context {
 
         ctx0 = lm_ggml_init(params);
 
-        lctx.inp_tokens  = nullptr;
-        lctx.inp_embd    = nullptr;
-        lctx.inp_pos     = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean    = nullptr;
-        lctx.inp_cls     = nullptr;
-        lctx.inp_s_copy  = nullptr;
-        lctx.inp_s_mask  = nullptr;
-        lctx.inp_s_seq   = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
+        lctx.inp_pos_bucket    = nullptr;
+        lctx.inp_embd_enc      = nullptr;
+        lctx.inp_KQ_mask_cross = nullptr;
     }
 
     void free() {
@@ -7159,8 +8508,9 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         lm_ggml_set_input(lctx.inp_K_shift);
 
-
         for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct lm_ggml_tensor * rope_factors = build_rope_factors(il);
             struct lm_ggml_tensor * tmp =
                 // we rotate only the first n_rot dimensions
@@ -7170,7 +8520,7 @@ struct llm_build_context {
                             lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                             lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                             0),
-                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
 
             cb(tmp, "K_shifted", il);
@@ -7220,6 +8570,9 @@ struct llm_build_context {
             }
 
             for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
                 lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx0, kv_self.k_l[il],
                         n_embd_k_gqa, nm,
                         lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
@@ -7279,7 +8632,7 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
-        if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
+        if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
 
@@ -7294,16 +8647,27 @@ struct llm_build_context {
     }
 
     struct lm_ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv,     LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-        } else {
-            lctx.inp_KQ_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-        }
+        lctx.inp_KQ_mask = causal
+            ? lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv,     LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD))
+            : lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         lm_ggml_set_input(lctx.inp_KQ_mask);
+
         return flash_attn ? lm_ggml_cast(ctx0, lctx.inp_KQ_mask, LM_GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
+    struct lm_ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        LM_GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv,     LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD))
+            : lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        lm_ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? lm_ggml_cast(ctx0, lctx.inp_KQ_mask_swa, LM_GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+    }
+
     struct lm_ggml_tensor * build_inp_mean() {
         lctx.inp_mean = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
@@ -7339,6 +8703,97 @@ struct llm_build_context {
         return lctx.inp_s_seq;
     }
 
+    struct lm_ggml_cgraph * append_pooling(struct lm_ggml_cgraph * gf) {
+        // find result_norm tensor for input
+        struct lm_ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
+        }
+        LM_GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
+
+        struct lm_ggml_tensor * cur;
+
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    struct lm_ggml_tensor * inp_mean = build_inp_mean();
+                    cur = lm_ggml_mul_mat(ctx0, lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, inp)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
+                {
+                    struct lm_ggml_tensor * inp_cls = build_inp_cls();
+                    cur = lm_ggml_get_rows(ctx0, inp, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
+            default:
+                {
+                    LM_GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
+
+        cb(cur, "result_embd_pooled", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_tensor * llm_build_pos_bucket(bool causal) {
+        if (causal) {
+            lctx.inp_pos_bucket = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_I32, n_kv,     n_tokens);
+        } else {
+            lctx.inp_pos_bucket = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_I32, n_tokens, n_tokens);
+        }
+
+        lm_ggml_set_input(lctx.inp_pos_bucket);
+        cb(lctx.inp_pos_bucket, "pos_bucket", -1);
+
+        return lctx.inp_pos_bucket;
+    }
+
+    struct lm_ggml_tensor * llm_build_pos_bias(struct lm_ggml_tensor * pos_bucket, struct lm_ggml_tensor * attn_rel_b) {
+        struct lm_ggml_tensor * pos_bucket_1d = lm_ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
+        cb(pos_bucket_1d, "pos_bucket_1d", -1);
+
+        struct lm_ggml_tensor * pos_bias = lm_ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
+        cb(pos_bias, "pos_bias", -1);
+
+        pos_bias = lm_ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], lm_ggml_element_size(pos_bias) * pos_bias->ne[0], lm_ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0],  0);
+        cb(pos_bias, "pos_bias", -1);
+
+        pos_bias = lm_ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
+        cb(pos_bias, "pos_bias", -1);
+
+        pos_bias = lm_ggml_cont(ctx0, pos_bias);
+        cb(pos_bias, "pos_bias", -1);
+
+        return pos_bias;
+    }
+
+    struct lm_ggml_tensor * llm_build_inp_embd_enc() {
+        const int64_t n_embd = hparams.n_embd;
+        lctx.inp_embd_enc = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_embd, n_outputs_enc);
+        lm_ggml_set_input(lctx.inp_embd_enc);
+        cb(lctx.inp_embd_enc, "embd_enc", -1);
+        return lctx.inp_embd_enc;
+    }
+
+    struct lm_ggml_tensor * llm_build_inp_KQ_mask_cross() {
+        lctx.inp_KQ_mask_cross = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_outputs_enc, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        lm_ggml_set_input(lctx.inp_KQ_mask_cross);
+        cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1);
+        return lctx.inp_KQ_mask_cross;
+    }
+
     struct lm_ggml_cgraph * build_llama() {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -7372,21 +8827,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -7395,19 +8850,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7430,10 +8885,10 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
@@ -7444,7 +8899,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_moe_ffn(ctx0, cur,
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
@@ -7459,10 +8914,7 @@ struct llm_build_context {
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = lm_ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7477,7 +8929,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -7513,25 +8965,25 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
                     case MODEL_7B:
                         Qcur = lm_ggml_rope_ext(
                             ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         Kcur = lm_ggml_rope_ext(
                             ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
@@ -7545,7 +8997,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7567,16 +9019,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7591,7 +9044,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -7627,29 +9080,29 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7671,16 +9124,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7693,7 +9147,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -7742,7 +9196,7 @@ struct llm_build_context {
                     cur = attn_norm;
                 }
 
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 struct lm_ggml_tensor * Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
@@ -7758,18 +9212,18 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = lm_ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7786,19 +9240,18 @@ struct llm_build_context {
 
             // feed forward
             {
-                cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
-                        model.layers[il].ffn_up,   NULL,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "l_out", il);
-
             cur = lm_ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7814,7 +9267,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -7859,21 +9312,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -7882,19 +9335,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -7926,7 +9379,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -7950,10 +9403,7 @@ struct llm_build_context {
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = lm_ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7968,7 +9418,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         // Grok
         // multiply logits by output_multiplier_scale of 0.5773502691896257
@@ -8019,7 +9469,7 @@ struct llm_build_context {
                 struct lm_ggml_tensor * Kcur = nullptr;
                 struct lm_ggml_tensor * Vcur = nullptr;
 
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -8035,19 +9485,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8070,7 +9520,7 @@ struct llm_build_context {
                                  LLM_NORM, cb, il);
             cb(cur, "attn_out_norm", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -8084,10 +9534,7 @@ struct llm_build_context {
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = lm_ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8102,7 +9549,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         cb(cur, "result_output", -1);
 
@@ -8144,7 +9591,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -8160,7 +9607,7 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8184,17 +9631,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8203,7 +9654,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -8235,13 +9686,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
@@ -8250,7 +9701,7 @@ struct llm_build_context {
                 Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 cb(Qcur, "Qcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8272,16 +9723,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8296,7 +9748,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -8319,8 +9771,6 @@ struct llm_build_context {
         if (model.arch != LLM_ARCH_JINA_BERT_V2) {
             inp_pos = build_inp_pos();
         }
-        struct lm_ggml_tensor * inp_mean = build_inp_mean();
-        struct lm_ggml_tensor * inp_cls  = build_inp_cls();
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8350,7 +9800,7 @@ struct llm_build_context {
 
             // self-attention
             if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                Qcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 if (model.layers[il].attn_q_norm) {
@@ -8360,7 +9810,7 @@ struct llm_build_context {
                             LLM_NORM, cb, il);
                 }
 
-                Kcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                Kcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
                 if (model.layers[il].attn_k_norm) {
@@ -8369,14 +9819,14 @@ struct llm_build_context {
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, cb, il);
                 }
-                Vcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                Vcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             } else {
                 // compute Q and K and RoPE them
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
@@ -8389,14 +9839,14 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -8425,7 +9875,7 @@ struct llm_build_context {
 
             lm_ggml_build_forward_expand(gf, cur);
 
-            cur = lm_ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
             if (model.layers[il].bo) {
                 cb(cur, "kqv_wo", il);
             }
@@ -8448,29 +9898,34 @@ struct llm_build_context {
             // attention layer norm
             cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il);
 
+            if (model.layers[il].attn_norm_2 != nullptr) {
+                cur = lm_ggml_add(ctx0, cur, inpL); // re-add the layer input
+                cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, cb, il);
+            }
+
             struct lm_ggml_tensor * ffn_inp = cur;
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
             if (model.arch == LLM_ARCH_BERT) {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
             } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL,                        NULL,
+                        model.layers[il].ffn_gate, NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
             } else {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             }
@@ -8490,28 +9945,6 @@ struct llm_build_context {
         cur = inpL;
         cb(cur, "result_embd", -1);
 
-        // pooling layer
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    // nop
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    cur = lm_ggml_mul_mat(ctx0, lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, cur)), inp_mean);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-                {
-                    cur = lm_ggml_get_rows(ctx0, cur, inp_cls);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                {
-                    LM_GGML_ASSERT(false && "Invalid pooling type");
-                } break;
-        }
-
         lm_ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -8547,7 +9980,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -8563,7 +9996,7 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8587,17 +10020,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8606,7 +10043,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -8653,7 +10090,7 @@ struct llm_build_context {
             {
                 cur = attn_norm;
 
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 if (model.layers[il].bqkv){
@@ -8691,13 +10128,13 @@ struct llm_build_context {
                     Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 } else {
                     Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 }
@@ -8721,16 +10158,17 @@ struct llm_build_context {
                         model.layers[il].ffn_norm_b,
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         model.layers[il].ffn_act,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8745,7 +10183,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -8785,21 +10223,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -8829,19 +10267,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8869,16 +10307,17 @@ struct llm_build_context {
                     // parallel residual
                     cur = inpSA;
                 }
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8894,7 +10333,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -8929,7 +10368,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -8948,18 +10387,18 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = lm_ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8981,16 +10420,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9005,7 +10445,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9043,36 +10483,36 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9093,15 +10533,16 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9116,7 +10557,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9157,36 +10598,36 @@ struct llm_build_context {
             // self_attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9209,7 +10650,7 @@ struct llm_build_context {
             cb(cur, "ffn_norm", il);
 
             lm_ggml_tensor * moe_out =
-                    llm_build_moe_ffn(ctx0, cur,
+                    llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
@@ -9222,17 +10663,17 @@ struct llm_build_context {
 
             // FFN shared expert
             {
-                lm_ggml_tensor * cur_gate_inp = lm_ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
+                lm_ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
                 cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
 
                 // sigmoid
                 lm_ggml_tensor * cur_gate = lm_ggml_div(ctx0, lm_ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
                 cb(cur_gate, "ffn_shexp_gate", il);
 
-                lm_ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up_shexp,   NULL,
-                        model.layers[il].ffn_gate_shexp, NULL,
-                        model.layers[il].ffn_down_shexp, NULL,
+                lm_ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur_ffn, "ffn_shexp", il);
@@ -9247,6 +10688,7 @@ struct llm_build_context {
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9261,7 +10703,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9303,7 +10745,7 @@ struct llm_build_context {
                 struct lm_ggml_tensor * Vcur = nullptr;
 
                 if (model.layers[il].wqkv) {
-                    cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
                     cb(cur, "wqkv", il);
 
                     cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9313,9 +10755,9 @@ struct llm_build_context {
                     Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                     Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
                 } else {
-                    Qcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                    Qcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
                 }
 
                 cb(Qcur, "Qcur", il);
@@ -9326,7 +10768,7 @@ struct llm_build_context {
                 Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = lm_ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -9337,12 +10779,12 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -9357,21 +10799,21 @@ struct llm_build_context {
 
             // FF
             {
-                ffn_output = llm_build_ffn(ctx0, attn_norm_output,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(ffn_output, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, ffn_output);
-            cb(cur, "l_out", il);
-
             cur = lm_ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
+            // input for next layer
             inpL = cur;
         }
 
@@ -9381,7 +10823,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output_no_bias", -1);
 
         cur = lm_ggml_add(ctx0, cur, model.output_b);
@@ -9427,7 +10869,7 @@ struct llm_build_context {
                 struct lm_ggml_tensor * Vcur = nullptr;
 
                 if (model.layers[il].wqkv) {
-                    cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
                     cb(cur, "wqkv", il);
 
                     Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
@@ -9435,9 +10877,9 @@ struct llm_build_context {
                     Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
                 }
                 else {
-                    Qcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = lm_ggml_add(ctx0, lm_ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                    Qcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = lm_ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
                 }
 
                 cb(Qcur, "Qcur", il);
@@ -9448,7 +10890,7 @@ struct llm_build_context {
                 Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = lm_ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -9457,12 +10899,12 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -9486,25 +10928,20 @@ struct llm_build_context {
             // special-case: the up and gate tensors are merged into a single tensor
             // TOOD: support into llm_build_ffn
             {
-                struct lm_ggml_tensor* up = lm_ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
-                cb(up, "ffn_up", il);
-
-                auto g = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], lm_ggml_row_size(up->type, up->ne[0]), 0));
-                auto y = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], lm_ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2));
-
-                y = lm_ggml_mul(ctx0, y, lm_ggml_silu(ctx0, g));
-                cb(y, "ffn_gate", il);
-
-                auto down = lm_ggml_mul_mat(ctx0, model.layers[il].ffn_down, y);
-                cb(down, "ffn_down", il);
-
-                cur = down;
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, residual, cur);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
+            // input for next layer
             inpL = cur;
         }
 
@@ -9514,7 +10951,7 @@ struct llm_build_context {
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9554,28 +10991,28 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_rope_ext(
                         ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                         ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9593,19 +11030,18 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up, NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, sa_out);
-            cb(cur, "l_out", il);
-
             cur = lm_ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9620,7 +11056,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9662,7 +11098,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9678,7 +11114,7 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9702,17 +11138,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -9721,7 +11161,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9757,7 +11197,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9773,19 +11213,19 @@ struct llm_build_context {
 
                 struct lm_ggml_tensor * Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 struct lm_ggml_tensor * Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9809,17 +11249,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -9828,7 +11272,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9866,21 +11310,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 // if (model.layers[il].bq) {
                 //     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                 //     cb(Qcur, "Qcur", il);
                 // }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 // if (model.layers[il].bk) {
                 //     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                 //     cb(Kcur, "Kcur", il);
                 // }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 // if (model.layers[il].bv) {
                 //     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -9889,19 +11333,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9922,15 +11366,16 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9945,7 +11390,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -9983,21 +11428,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10006,19 +11451,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10039,15 +11484,16 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10062,7 +11508,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10113,21 +11559,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10136,19 +11582,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10175,10 +11621,10 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
@@ -10189,6 +11635,7 @@ struct llm_build_context {
             cb(cur, "hidden_scaled_ffn", -1);
 
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10208,7 +11655,7 @@ struct llm_build_context {
         cb(cur, "lmhead_scaling", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.tok_embd, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10245,18 +11692,18 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_rope_ext(
                         ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -10265,11 +11712,11 @@ struct llm_build_context {
 
                 Kcur = lm_ggml_rope_ext(
                         ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -10291,16 +11738,17 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up, NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = lm_ggml_add(ctx0, cur, sa_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10315,7 +11763,141 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_gemma2() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        inpL = lm_ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct lm_ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct lm_ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct lm_ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = lm_ggml_rope_ext(
+                        ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case e_model::MODEL_9B:  Qcur = lm_ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case e_model::MODEL_27B: Qcur = lm_ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: LM_GGML_ASSERT(false);
+                };
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = lm_ggml_rope_ext(
+                        ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = lm_ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            struct lm_ggml_tensor * sa_out = lm_ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = llm_build_norm(ctx0, sa_out, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                model.layers[il].ffn_post_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = lm_ggml_add(ctx0, cur, sa_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = lm_ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = lm_ggml_tanh(ctx0, cur);
+        cur = lm_ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10323,6 +11905,7 @@ struct llm_build_context {
         return gf;
     }
 
+
     struct lm_ggml_cgraph * build_starcoder2() {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -10353,21 +11936,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10376,19 +11959,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10410,14 +11993,16 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
             cb(cur, "ffn_out", il);
+
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10432,7 +12017,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10484,7 +12069,7 @@ struct llm_build_context {
             cb(cur, "attn_norm", il);
 
             // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
-            struct lm_ggml_tensor * xz = lm_ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
+            struct lm_ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
             // => {d_inner, n_tokens}
             struct lm_ggml_tensor * x = lm_ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
@@ -10523,14 +12108,14 @@ struct llm_build_context {
             // ssm
             {
                 // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
-                struct lm_ggml_tensor * x_db = lm_ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                struct lm_ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
                 // split
                 struct lm_ggml_tensor * dt = lm_ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
                 struct lm_ggml_tensor * B  = lm_ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], lm_ggml_element_size(x_db)*dt_rank);
                 struct lm_ggml_tensor * C  = lm_ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], lm_ggml_element_size(x_db)*(dt_rank+d_state));
 
                 // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
-                dt = lm_ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
                 dt = lm_ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
 
                 // Custom operator to optimize the parallel associative scan
@@ -10561,11 +12146,12 @@ struct llm_build_context {
                 y = lm_ggml_mul(ctx0, y, lm_ggml_silu(ctx0, z));
 
                 // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
             }
 
             // residual
             cur = lm_ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10579,7 +12165,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10618,21 +12204,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10666,19 +12252,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10695,10 +12281,10 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, ffn_inp,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
@@ -10707,6 +12293,7 @@ struct llm_build_context {
             // add together residual + FFN + self-attention
             cur = lm_ggml_add(ctx0, cur, inpL);
             cur = lm_ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10721,7 +12308,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         if (f_logit_scale) {
             cur = lm_ggml_scale(ctx0, cur, f_logit_scale);
@@ -10774,21 +12361,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Qcur = lm_ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Kcur = lm_ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Vcur = lm_ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -10797,19 +12384,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, nullptr,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10831,10 +12418,10 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
@@ -10842,10 +12429,7 @@ struct llm_build_context {
             cur = lm_ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = lm_ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10860,7 +12444,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -10868,16 +12452,14 @@ struct llm_build_context {
         return gf;
     }
 
-    struct lm_ggml_cgraph * build_gptneox() {
+    struct lm_ggml_cgraph * build_openelm() {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
         LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
-
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
@@ -10887,15 +12469,142 @@ struct llm_build_context {
         struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
+            const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_qkv = 2*n_head_kv + n_head;
 
-            // self-attention
+            cur = inpL;
+            struct lm_ggml_tensor * residual = cur;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = lm_ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
+
+                struct lm_ggml_tensor * Qcur = lm_ggml_cont(ctx0, lm_ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
+                cb(Qcur, "Qcur", il);
+
+                struct lm_ggml_tensor * Kcur = lm_ggml_cont(ctx0, lm_ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
+                cb(Kcur, "Kcur", il);
+
+                struct lm_ggml_tensor * Vcur = lm_ggml_cont(ctx0, lm_ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = lm_ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = lm_ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = lm_ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
+                cb(Qcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                residual = lm_ggml_get_rows(ctx0, residual, inp_out_ids);
+                cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_gptneox() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
             {
-                cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -10911,19 +12620,19 @@ struct llm_build_context {
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10948,10 +12657,10 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
@@ -10959,8 +12668,12 @@ struct llm_build_context {
                 cur = lm_ggml_add(ctx0, cur, inpL);
                 cb(cur, "ffn_out", il);
 
-                inpL = lm_ggml_add(ctx0, cur, attn_out);
-                cb(inpL, "l_out", il);
+                cur = lm_ggml_add(ctx0, cur, attn_out);
+                cur = lctx.cvec.apply_to(ctx0, cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
             } else {
                 // attention and ffn are computed sequentially
                 // x = x + attn(ln1(x))
@@ -10975,16 +12688,20 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
 
-                inpL = lm_ggml_add(ctx0, cur, ffn_inp);
-                cb(inpL, "l_out", il);
+                cur = lm_ggml_add(ctx0, cur, ffn_inp);
+                cur = lctx.cvec.apply_to(ctx0, cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
             }
         }
 
@@ -10994,7 +12711,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -11035,30 +12752,30 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = lm_ggml_rope_ext(
                     ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11080,10 +12797,10 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
@@ -11097,7 +12814,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm_exps", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -11111,10 +12828,7 @@ struct llm_build_context {
             cur = lm_ggml_add(ctx0, cur, ffn_out);
             cb(cur, "ffn_out", il);
 
-            lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = lm_ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -11129,7 +12843,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
@@ -11198,57 +12912,81 @@ struct llm_build_context {
                 }
 
                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct lm_ggml_tensor * q_nope = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, lm_ggml_element_size(q) * hparams.n_embd_head_k, lm_ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
+                struct lm_ggml_tensor * q_nope = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k),
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
                 cb(q_nope, "q_nope", il);
+
                 // and {n_head * n_embd_head_qk_rope, n_tokens}
-                struct lm_ggml_tensor * q_pe = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, lm_ggml_element_size(q) * hparams.n_embd_head_k, lm_ggml_element_size(q) * hparams.n_embd_head_k * n_head, lm_ggml_element_size(q) * n_embd_head_qk_nope);
+                struct lm_ggml_tensor * q_pe = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k),
+                        lm_ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        lm_ggml_row_size(q->type, n_embd_head_qk_nope));
                 cb(q_pe, "q_pe", il);
 
                 // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                struct lm_ggml_tensor * compressed_kv_pe = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(compressed_kv_pe, "compressed_kv_pe", il);
+                struct lm_ggml_tensor * kv_pe_compresseed = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
 
                 // split into {kv_lora_rank, n_tokens}
-                struct lm_ggml_tensor * compressed_kv = lm_ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
-                cb(compressed_kv, "compressed_kv", il);
+                struct lm_ggml_tensor * kv_compressed = lm_ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
                 // and {n_embd_head_qk_rope, n_tokens}
-                struct lm_ggml_tensor * k_pe = lm_ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], lm_ggml_element_size(compressed_kv_pe)*kv_lora_rank);
+                struct lm_ggml_tensor * k_pe = lm_ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        lm_ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
+                kv_compressed = lm_ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
-                cb(compressed_kv, "compressed_kv", il);
+                cb(kv_compressed, "kv_compressed", il);
 
                 // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct lm_ggml_tensor * kv = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv);
+                struct lm_ggml_tensor * kv = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
                 cb(kv, "kv", il);
 
                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct lm_ggml_tensor * k_nope = lm_ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, lm_ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), lm_ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+                struct lm_ggml_tensor * k_nope = lm_ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        lm_ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        lm_ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
                 cb(k_nope, "k_nope", il);
 
                 // and {n_head * n_embd_head_v, n_tokens}
-                struct lm_ggml_tensor * v_states = lm_ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, lm_ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), lm_ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), lm_ggml_element_size(kv) * n_embd_head_qk_nope);
+                struct lm_ggml_tensor * v_states = lm_ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        lm_ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        lm_ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        lm_ggml_row_size(kv->type, (n_embd_head_qk_nope)));
                 cb(v_states, "v_states", il);
 
                 v_states = lm_ggml_cont(ctx0, v_states);
                 cb(v_states, "v_states", il);
 
-                v_states = lm_ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, lm_ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
+                v_states = lm_ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                    lm_ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                    0);
                 cb(v_states, "v_states", il);
 
+                q_pe = lm_ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = lm_ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor_scaled, beta_fast, beta_slow
                 );
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
+                k_pe = lm_ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 k_pe = lm_ggml_rope_ext(
-                    ctx0, lm_ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ctx0, k_pe, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor_scaled, beta_fast, beta_slow
                 );
                 cb(k_pe, "k_pe", il);
@@ -11256,95 +12994,753 @@ struct llm_build_context {
                 struct lm_ggml_tensor * q_states = lm_ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(q_states, "q_states", il);
 
-                struct lm_ggml_tensor * k_states = lm_ggml_concat(ctx0, k_nope, lm_ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
+                struct lm_ggml_tensor * k_states = lm_ggml_concat(ctx0, k_nope, lm_ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                lm_ggml_tensor * moe_out =
+                        llm_build_moe_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, false,
+                            true, hparams.expert_weights_scale,
+                            cb, il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    lm_ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = lm_ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_bitnet() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct lm_ggml_tensor * inpSA = inpL;
+
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                Qcur = lm_ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                // B1.K
+                struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                Kcur = lm_ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                // B1.V
+                struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                Vcur = lm_ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        NULL, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_sub_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_sub_norm", il);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cur = lm_ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                if (model.layers[il].bo) {
+                    cur = lm_ggml_add(ctx0, cur, model.layers[il].bo);
+                }
+                cb(cur, "attn_o_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward forward
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
+                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
+                    NULL,                      NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_sub_out", il);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                            model.layers[il].ffn_sub_norm, NULL,
+                            LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_sub_norm", il);
+
+            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
+            cur = lm_ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            cb(cur, "ffn_down", il);
+
+            cur = lm_ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_t5() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        if (lctx.is_encoding) {
+            struct lm_ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
+
+            // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+            struct lm_ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
+
+            for (int il = 0; il < n_layer; ++il) {
+                struct lm_ggml_tensor * inpSA = inpL;
+
+                // norm
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].attn_norm_enc, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+
+                // self-attention
+                {
+                    struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq_enc, cur);
+                    cb(Qcur, "Qcur", il);
+
+                    struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk_enc, cur);
+                    cb(Kcur, "Kcur", il);
+
+                    struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv_enc, cur);
+                    cb(Vcur, "Vcur", il);
+
+                    Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                    struct lm_ggml_tensor * q =                 lm_ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                    struct lm_ggml_tensor * k = lm_ggml_cont(ctx0, lm_ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+                    struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    struct lm_ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+                    struct lm_ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
+                    struct lm_ggml_tensor * kq_b = lm_ggml_add(ctx0, kq, pos_bias);
+                    cb(kq_b, "kq_b", il);
+
+                    kq = lm_ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    struct lm_ggml_tensor * v = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, lm_ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+                    cb(v, "v", il);
+
+                    struct lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx0, lm_ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+                    cb(kqv, "kqv", il);
+
+                    struct lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur = lm_ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                    cb(cur, "kqv_merged_cont", il);
+
+                    lm_ggml_build_forward_expand(gf, cur);
+
+                    cur = lm_ggml_mul_mat(ctx0, model.layers[il].wo_enc, cur);
+                    cb(cur, "kqv_out", il);
+                }
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                }
+
+                struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                {
+                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                            model.layers[il].ffn_norm_enc, NULL,
+                            LLM_NORM_RMS, cb, il);
+                    cb(cur, "ffn_norm", il);
+
+                    // T5 uses relu, flan-T5 uses gelu-gated
+                    cur = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_enc,   NULL, NULL,
+                            model.layers[il].ffn_gate_enc, NULL, NULL,
+                            model.layers[il].ffn_down_enc, NULL, NULL,
+                            NULL,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                            cb, il);
+                    cb(cur, "ffn_out", il);
+                }
+
+                cur = lm_ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+                if (layer_dir != nullptr) {
+                    cur = lm_ggml_add(ctx0, cur, layer_dir);
+                }
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+
+            cur = inpL;
+            cb(cur, "result_embd", -1);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.output_norm_enc, NULL,
+                    LLM_NORM_RMS, cb, -1);
+            cb(cur, "result_norm", -1);
+        } else {
+            LM_GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
+
+            struct lm_ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
+            struct lm_ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
+
+            struct lm_ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
+            struct lm_ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
+
+            for (int il = 0; il < n_layer; ++il) {
+                struct lm_ggml_tensor * inpSA = inpL;
+
+                // norm
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+
+                // self-attention
+                {
+                    struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                    cb(Qcur, "Qcur", il);
+
+                    struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                    cb(Kcur, "Kcur", il);
+
+                    struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                    cb(Vcur, "Vcur", il);
+
+                    llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+
+                    struct lm_ggml_tensor * k =
+                        lm_ggml_view_3d(ctx0, kv_self.k_l[il],
+                                n_embd_head_k, n_kv, n_head_kv,
+                                lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                                lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                                0);
+                    cb(k, "k", il);
+
+                    struct lm_ggml_tensor * v =
+                        lm_ggml_view_3d(ctx0, kv_self.v_l[il],
+                                n_kv, n_embd_head_v, n_head_kv,
+                                lm_ggml_element_size(kv_self.v_l[il])*n_ctx,
+                                lm_ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                                0);
+                    cb(v, "v", il);
+
+                    Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                    struct lm_ggml_tensor * q = lm_ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+
+                    struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    struct lm_ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+                    struct lm_ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
+                    struct lm_ggml_tensor * kq_b = lm_ggml_add(ctx0, kq, pos_bias);
+                    cb(kq_b, "kq_b", il);
+
+                    kq = lm_ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    struct lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx0, v, kq);
+                    cb(kqv, "kqv", il);
+
+                    struct lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur = lm_ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                    cb(cur, "kqv_merged_cont", il);
+
+                    lm_ggml_build_forward_expand(gf, cur);
+
+                    cur = lm_ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+                    cb(cur, "kqv_out", il);
+                }
+
+                cur = lm_ggml_add(ctx0, cur, inpSA);
+                cb(cur, "cross_inp", il);
+
+                struct lm_ggml_tensor * inpCA = cur;
+
+                // norm
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_norm_cross, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm_cross", il);
+
+                // cross-attention
+                {
+                    struct lm_ggml_tensor * Qcur = lm_ggml_mul_mat(ctx0, model.layers[il].wq_cross, cur);
+                    cb(Qcur, "Qcur", il);
+
+                    struct lm_ggml_tensor * Kcur = lm_ggml_mul_mat(ctx0, model.layers[il].wk_cross, embd_enc);
+                    cb(Kcur, "Kcur", il);
+
+                    struct lm_ggml_tensor * Vcur = lm_ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
+                    cb(Vcur, "Vcur", il);
+
+                    Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+
+                    struct lm_ggml_tensor * q =                 lm_ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                    struct lm_ggml_tensor * k = lm_ggml_cont(ctx0, lm_ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+                    struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    kq = lm_ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    struct lm_ggml_tensor * v = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, lm_ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
+                    cb(v, "v", il);
+
+                    struct lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx0, lm_ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
+                    cb(kqv, "kqv", il);
+
+                    struct lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur = lm_ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                    cb(cur, "kqv_merged_cont", il);
+
+                    lm_ggml_build_forward_expand(gf, cur);
+
+                    cur = lm_ggml_mul_mat(ctx0, model.layers[il].wo_cross, cur);
+                    cb(cur, "kqv_out", il);
+                }
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                    inpCA = lm_ggml_get_rows(ctx0, inpCA, inp_out_ids);
+                }
+
+                struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpCA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                {
+                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                            model.layers[il].ffn_norm, NULL,
+                            LLM_NORM_RMS, cb, il);
+                    cb(cur, "ffn_norm", il);
+
+                    // T5 uses relu, flan-T5 uses gelu-gated
+                    cur = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up,   NULL, NULL,
+                            model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL,
+                            NULL,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                            cb, il);
+                    cb(cur, "ffn_out", il);
+                }
+
+                cur = lm_ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                lm_ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+                if (layer_dir != nullptr) {
+                    cur = lm_ggml_add(ctx0, cur, layer_dir);
+                }
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+
+            cur = inpL;
+            cb(cur, "result_embd", -1);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.output_norm, NULL,
+                    LLM_NORM_RMS, cb, -1);
+            cb(cur, "result_norm", -1);
+
+            // lm_head
+            cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+            cb(cur, "result_output", -1);
+        }
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_jais() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct lm_ggml_tensor * Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
+                struct lm_ggml_tensor * Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
+                struct lm_ggml_tensor * Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = lm_ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = lm_ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        lm_ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct lm_ggml_cgraph * build_chatglm() {
+        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct lm_ggml_tensor * cur;
+        struct lm_ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct lm_ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct lm_ggml_tensor * inpSA = inpL;
+
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                struct lm_ggml_tensor * Qcur = nullptr;
+                struct lm_ggml_tensor * Kcur = nullptr;
+                struct lm_ggml_tensor * Vcur = nullptr;
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
+                Qcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = lm_ggml_rope_ext(
+                    ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
             }
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
                 struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
                 cur   = lm_ggml_get_rows(ctx0,   cur, inp_out_ids);
                 inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
+            // Add the input
             struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+            // FF
+            {
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
+                        model.layers[il].ffn_norm,
+                        NULL,
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                lm_ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, cur,
-                            model.layers[il].ffn_gate_inp,
-                            model.layers[il].ffn_up_exps,
-                            model.layers[il].ffn_gate_exps,
-                            model.layers[il].ffn_down_exps,
-                            n_expert, n_expert_used,
-                            LLM_FFN_SILU, false,
-                            true, hparams.expert_weights_scale,
-                            cb, il);
-                cb(moe_out, "ffn_moe_out", il);
-
-                // FFN shared expert
-                {
-                    lm_ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur,
-                            model.layers[il].ffn_up_shexp,   NULL,
-                            model.layers[il].ffn_gate_shexp, NULL,
-                            model.layers[il].ffn_down_shexp, NULL,
-                            NULL,
-                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                    cb(ffn_shexp, "ffn_shexp", il);
 
-                    cur = lm_ggml_add(ctx0, moe_out, ffn_shexp);
-                    cb(cur, "ffn_out", il);
-                }
             }
 
-            cur = lm_ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
+            inpL = lm_ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
         }
 
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                NULL,
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
-        // lm_head
-        cur = lm_ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         lm_ggml_build_forward_expand(gf, cur);
 
         return gf;
     }
-
 };
 
 static struct lm_ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) {
@@ -11425,7 +13821,8 @@ static struct lm_ggml_cgraph * llama_build_graph(
         if (batch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
                 for (auto * backend : lctx.backends) {
-                    if (lm_ggml_backend_buft_supports_backend(lctx.model.buft_layer[il].buft, backend)) {
+                    if (lm_ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft) &&
+                        (lm_ggml_backend_supports_op(backend, cur) || lm_ggml_backend_offload_op(backend, cur))) {
                         lm_ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
                         break;
                     }
@@ -11531,6 +13928,10 @@ static struct lm_ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gemma();
             } break;
+        case LLM_ARCH_GEMMA2:
+            {
+                result = llm.build_gemma2();
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 result = llm.build_starcoder2();
@@ -11555,6 +13956,10 @@ static struct lm_ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OPENELM:
+            {
+                result = llm.build_openelm();
+            } break;
         case LLM_ARCH_GPTNEOX:
             {
                 result = llm.build_gptneox();
@@ -11567,10 +13972,31 @@ static struct lm_ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_deepseek2();
             } break;
+        case LLM_ARCH_CHATGLM:
+            {
+                result = llm.build_chatglm();
+            } break;
+        case LLM_ARCH_BITNET:
+            {
+                result = llm.build_bitnet();
+            } break;
+        case LLM_ARCH_T5:
+            {
+                result = llm.build_t5();
+            } break;
+        case LLM_ARCH_JAIS:
+            {
+                result = llm.build_jais();
+            } break;
         default:
             LM_GGML_ASSERT(false);
     }
 
+    // add on pooling layer
+    if (lctx.cparams.embeddings) {
+        result = llm.append_pooling(result);
+    }
+
     llm.free();
 
     return result;
@@ -11600,6 +14026,30 @@ static void llama_set_s_copy(llama_context & lctx) {
     }
 }
 
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min(relative_position, 0);
+    }
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+    return relative_bucket;
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -11660,18 +14110,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // (!a || b) is a logical implication (a -> b)
         // !hparams.causal_attn -> !cparams.causal_attn
         (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention with embedding models is not supported"
+        "causal attention is not supported by this model"
     );
 
     if (lctx.inp_KQ_mask) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
+        if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv     = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;
 
             LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data     = (float *) lctx.inp_KQ_mask->data;
+            float * data_swa = nullptr;
+
+            if (lctx.inp_KQ_mask_swa) {
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+            }
 
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
@@ -11693,6 +14148,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
                     }
                 }
 
@@ -11705,7 +14168,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         } else {
             // when using kv cache, the mask needs to match the kv cache size
             const int64_t n_tokens = batch.n_tokens;
-            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+            const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
 
             LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
@@ -11739,7 +14202,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         LM_GGML_ASSERT(lctx.inp_mean);
@@ -11771,7 +14234,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         LM_GGML_ASSERT(lctx.inp_cls);
@@ -11792,6 +14255,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        LM_GGML_ASSERT(lctx.inp_cls);
+        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+
+        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * lm_ggml_element_size(lctx.inp_cls));
+
+        std::vector last_pos(n_tokens, -1);
+        std::vector last_row(n_tokens, -1);
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            const llama_pos    pos    = batch.pos[i];
+
+            LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+
+            if (pos >= last_pos[seq_id]) {
+                last_pos[seq_id] = pos;
+                last_row[seq_id] = i;
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+    }
+
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
@@ -11838,6 +14332,70 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     }
+
+    if (lctx.inp_pos_bucket) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+
+        int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+
+        if (!lctx.is_encoding) {
+            const int64_t n_kv = kv_self.n;
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    for (int i = 0; i < n_kv; ++i) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                    }
+                }
+            }
+        } else {
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    for (int i = 0; i < n_tokens; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                    }
+                }
+            }
+        }
+    }
+
+    if (!lctx.is_encoding && lctx.inp_embd_enc) {
+        assert(lctx.inp_embd_enc->type == LM_GGML_TYPE_F32);
+        assert((size_t) lm_ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
+
+        lm_ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, lm_ggml_nbytes(lctx.inp_embd_enc));
+    }
+
+    if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
+        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
+        const int64_t n_tokens = batch.n_tokens;
+
+        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
+
+        float * data = (float *) lctx.inp_KQ_mask_cross->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_output_enc; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < batch.n_seq_id[j]; ++s) {
+                        const llama_seq_id seq_id = batch.seq_id[j][s];
+                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
+                            f = 0.0f;
+                        }
+                    }
+                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                }
+            }
+
+            for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_output_enc; ++j) {
+                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                }
+            }
+        }
+    }
 }
 
 // Make sure enough space is available for outputs.
@@ -11853,8 +14411,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
@@ -11922,6 +14480,11 @@ static void llama_graph_compute(
         lm_ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
         lm_ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
+#ifdef LM_GGML_USE_BLAS
+    if (lctx.backend_blas != nullptr) {
+        lm_ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
+    }
+#endif
 
     lm_ggml_backend_sched_graph_compute_async(lctx.sched, gf);
 
@@ -11941,6 +14504,7 @@ static int llama_decode_internal(
          llama_context & lctx,
            llama_batch   batch_all) { // TODO: rename back to batch
 
+    lctx.is_encoding = false;
     const uint32_t n_tokens_all = batch_all.n_tokens;
 
     if (n_tokens_all == 0) {
@@ -11973,17 +14537,21 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
+    // TODO: simplify or deprecate
     std::vector pos;
     std::vector                   n_seq_id;
     std::vector            seq_id_arr;
     std::vector> seq_id;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     // count outputs
-    if (batch_all.logits) {
+    if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
@@ -12029,7 +14597,7 @@ static int llama_decode_internal(
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits) {
+            if (u_batch.logits && !embd_pooled) {
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
@@ -12114,47 +14682,19 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (!hparams.causal_attn) {
-            res = nullptr; // do not extract logits for embedding models such as BERT
-
-            // token or sequence embeddings
-            embd = gf->nodes[gf->n_nodes - 1];
-
-            LM_GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
         } else if (cparams.embeddings) {
-            // the embeddings could be in the second to last tensor, or any of the previous tensors
-            int i_embd = gf->n_nodes - 2;
-            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
-                i_embd = gf->n_nodes - i;
-                if (i_embd < 0) { break; }
-                embd = gf->nodes[i_embd];
-            }
-            LM_GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
-
-            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
-            if (!cparams.causal_attn) {
-                res = nullptr; // do not extract logits when not needed
-                // skip computing logits
-                // TODO: is this safe?
-                gf->n_nodes = i_embd + 1;
+            res = nullptr; // do not extract logits for embedding case
+            embd = gf->nodes[gf->n_nodes - 1];
+            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+                embd = gf->nodes[gf->n_nodes - 2];
             }
+            LM_GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             LM_GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        // for big prompts, if BLAS is enabled, it is better to use only one thread
-        // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-        // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
-        //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
-        //       with the BLAS calls. need a better solution
-        // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is
-        //                   being processed then Accelerate/BLAS will not be involved, so capping would limit performance.
-        if (n_tokens >= 32 && hparams.n_expert == 0 && lm_ggml_cpu_has_blas() && !lm_ggml_cpu_has_gpublas()) {
-            n_threads = std::min(4, n_threads);
-        }
-
         lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
         llama_set_inputs(lctx, u_batch);
@@ -12171,12 +14711,6 @@ static int llama_decode_internal(
             }
         }
 
-#ifdef LM_GGML_PERF
-        // print timing information per ggml operation (for debugging purposes)
-        // requires LM_GGML_PERF to be defined
-        lm_ggml_graph_print(gf);
-#endif
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    lm_ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -12217,48 +14751,180 @@ static int llama_decode_internal(
                             lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                         }
                     } break;
-                case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings
+                        auto & embd_seq_out = lctx.embd_seq;
+                        embd_seq_out.clear();
+
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = u_batch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
-                        LM_GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
+                        LM_GGML_ASSERT(false && "unknown pooling type");
+                    } break;
+            }
+        }
+        n_outputs_prev += lctx.n_outputs;
+    }
+
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    lctx.n_outputs = n_outputs;
+
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //llama_synchronize(&lctx);
+
+    // decide if we need to defrag the kv cache
+    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
+        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+            llama_kv_cache_defrag(kv_self);
+        }
+    }
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    lm_ggml_backend_sched_reset(lctx.sched);
+
+    return 0;
+}
+
+// encode a batch of tokens by evaluating the encoder part of the transformer
+//
+//   - lctx:      llama context
+//   - batch:     batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_encode_internal(
+         llama_context & lctx,
+           llama_batch   batch) {
+
+    lctx.is_encoding = true;
+
+    const uint32_t n_tokens = batch.n_tokens;
+
+    if (n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+        return -1;
+    }
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    LM_GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = lm_ggml_time_us();
+    }
+
+    lctx.n_queued_tokens += n_tokens;
+
+    const int64_t n_embd = hparams.n_embd;
+
+    // TODO: simplify or deprecate
+    std::vector pos;
+    std::vector                   n_seq_id;
+    std::vector            seq_id_arr;
+    std::vector> seq_id;
+
+    // reserve output buffer
+    if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
+        return -2;
+    };
+
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        lctx.output_ids[i] = i;
+    }
+
+    lctx.inp_embd_enc = NULL;
+    lctx.n_outputs = n_tokens;
+
+    const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    LM_GGML_ASSERT(n_threads > 0);
+
+    // helpers for smoother batch API transition
+    // after deprecating the llama_eval calls, these will be removed
+    if (batch.pos == nullptr) {
+        pos.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
+        }
 
-                        // extract sequence embeddings
-                        auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();
+        batch.pos = pos.data();
+    }
 
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = u_batch.seq_id[i][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        LM_GGML_ASSERT(false && "unknown pooling type");
-                    } break;
-            }
+    if (batch.seq_id == nullptr) {
+        n_seq_id.resize(n_tokens);
+        seq_id.resize(n_tokens);
+        seq_id_arr.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            n_seq_id[i] = 1;
+            seq_id[i].resize(1);
+            seq_id[i][0] = batch.all_seq_id;
+            seq_id_arr[i] = seq_id[i].data();
         }
-        n_outputs_prev += lctx.n_outputs;
+
+        batch.n_seq_id = n_seq_id.data();
+        batch.seq_id = seq_id_arr.data();
     }
 
-    // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
+    lm_ggml_backend_sched_reset(lctx.sched);
+    lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    // wait for the computation to finish (automatically done when obtaining the model output)
-    //llama_synchronize(&lctx);
+    lm_ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
 
-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+    // the output embeddings after the final encoder normalization
+    struct lm_ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
 
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+    LM_GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
 
-            llama_kv_cache_defrag(kv_self);
+    lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+    llama_set_inputs(lctx, batch);
+
+    llama_graph_compute(lctx, gf, n_threads);
+
+    // extract embeddings
+    if (embd) {
+        lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+        LM_GGML_ASSERT(backend_embd != nullptr);
+
+        // extract token embeddings
+        LM_GGML_ASSERT(lctx.embd != nullptr);
+
+        lctx.embd_enc.resize(n_tokens*n_embd);
+        float * embd_out = lctx.embd_enc.data();
+
+        lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+
+        // remember the sequence ids used during the encoding - needed for cross attention later
+        lctx.seq_ids_enc.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = batch.seq_id[i][s];
+                lctx.seq_ids_enc[i].insert(seq_id);
+            }
         }
     }
 
@@ -12269,7 +14935,6 @@ static int llama_decode_internal(
     return 0;
 }
 
-
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
@@ -12496,6 +15161,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
+            LM_GGML_ASSERT(false && "Deepseek2 does not support K-shift");
+        }
+
         {
             lm_ggml_backend_sched_reset(lctx.sched);
 
@@ -12583,27 +15252,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
 
 static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
 }
 
 static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
 }
 
 static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
 }
 
 static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
 }
 
 static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+static bool llama_is_unused_token(const llama_vocab& vocab, llama_token id) {
+    LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
 }
 
 static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
@@ -12611,7 +15285,8 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
     LM_GGML_ASSERT(llama_is_byte_token(vocab, id));
     const auto & token_data = vocab.id_to_token.at(id);
     switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM: {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
             auto buf = token_data.text.substr(3, 2);
             return strtol(buf.c_str(), NULL, 16);
         }
@@ -12631,7 +15306,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
     LM_GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
     static const char * hex = "0123456789ABCDEF";
     switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM: {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
             const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
             auto token = vocab.token_to_id.find(buf);
             if (token != vocab.token_to_id.end()) {
@@ -12835,107 +15511,144 @@ struct llm_bigram_bpe {
 };
 
 struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector & output) {
-        int final_prev_index = -1;
-        bool ignore_merges = false;
-
-        std::vector word_collection;
-        switch (vocab.type) {
-            case LLAMA_VOCAB_TYPE_BPE:
-                switch (vocab.type_pre) {
-                    case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
-                        ignore_merges = true;
-                        word_collection = unicode_regex_split(text, {
-                            // original regex from tokenizer.json
-                            //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-
-                            // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DBRX:
-                    case LLAMA_VOCAB_PRE_TYPE_SMAUG:
-                        word_collection = unicode_regex_split(text, {
-                            // same as llama3
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
-                        word_collection = unicode_regex_split(text, {
-                            "[\r\n]",
-                            "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-                            "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
-                            "\\s+$",
-                            "[一-龥ࠀ-一가-퟿]+",
-                            "\\p{N}+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
-                        word_collection = unicode_regex_split(text, {
-                            "[\r\n]",
-                            "\\s?\\p{L}+",
-                            "\\s?\\p{P}+",
-                            "[一-龥ࠀ-一가-퟿]+",
-                            "\\p{N}",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_FALCON:
-                        word_collection = unicode_regex_split(text, {
-                            "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                            "[0-9][0-9][0-9]",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_MPT:
-                        // TODO: MPT pre-tokenization regexes are unknown
-                        //       the following are close, but not exact. run the following:
-                        //       ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
-                        LM_GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
-                        word_collection = unicode_regex_split(text, {
-                            "\\s?\\p{L}+",
-                            "\\s?\\p{P}+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_STARCODER:
-                    case LLAMA_VOCAB_PRE_TYPE_REFACT:
-                    case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
-                        word_collection = unicode_regex_split(text, {
-                            "\\p{N}",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_GPT2:
-                    case LLAMA_VOCAB_PRE_TYPE_OLMO:
-                        word_collection = unicode_regex_split(text, {
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
-                    case LLAMA_VOCAB_PRE_TYPE_QWEN2:
-                        word_collection = unicode_regex_split(text, {
-                            // original regex from tokenizer.json
-                            // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    default:
-                        // default regex for BPE tokenization pre-processing
-                        word_collection = unicode_regex_split(text, {
-                            "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                            "\\p{N}+",
-                            "[0-9][0-9][0-9]",
-                        });
-                        break;
-                }
+    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+        LM_GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
+        switch (vocab.type_pre) {
+            case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+
+                    // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DBRX:
+            case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+                regex_exprs = {
+                    // same as llama3
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
+                regex_exprs = {
+                    "[\r\n]",
+                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
+                    "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
+                    "\\s+$",
+                    "[一-龥ࠀ-一가-퟿]+",
+                    "\\p{N}+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
+                regex_exprs = {
+                    "[\r\n]",
+                    "\\s?\\p{L}+",
+                    "\\s?\\p{P}+",
+                    "[一-龥ࠀ-一가-퟿]+",
+                    "\\p{N}",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_FALCON:
+                regex_exprs = {
+                    "[\\p{P}\\$\\+<=>\\^~\\|`]+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    "[0-9][0-9][0-9]",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_STARCODER:
+            case LLAMA_VOCAB_PRE_TYPE_REFACT:
+            case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
+                regex_exprs = {
+                    "\\p{N}",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_GPT2:
+            case LLAMA_VOCAB_PRE_TYPE_MPT:
+            case LLAMA_VOCAB_PRE_TYPE_OLMO:
+            case LLAMA_VOCAB_PRE_TYPE_JAIS:
+                regex_exprs = {
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
+            case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_PORO:
+                regex_exprs = {
+                    " ?[^(\\s|.,!?…。,、।۔،)]+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
+                regex_exprs = {
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_VIKING:
+                regex_exprs = {
+                    " ?[^(\\s|.,!?…。,、।۔،)]+",
+                    "\\p{N}",
+                };
                 break;
             default:
-                LM_GGML_ASSERT(false);
+                // default regex for BPE tokenization pre-processing
+                regex_exprs = {
+                    "[\\p{P}\\$\\+<=>\\^~\\|]+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    "\\p{N}+",
+                    "[0-9][0-9][0-9]",
+                };
                 break;
         }
+    }
+
+    void append(const llama_vocab::id token_id, std::vector & output) const {
+        output.push_back(token_id);
+    }
+
+    bool append_bos(std::vector & output) const {
+        if (vocab.tokenizer_add_bos) {
+            LM_GGML_ASSERT(vocab.special_bos_id != -1);
+            output.push_back(vocab.special_bos_id);
+            return true;
+        }
+        return false;
+    }
+
+    bool append_eos(std::vector & output) const {
+        if (vocab.tokenizer_add_eos) {
+            LM_GGML_ASSERT(vocab.special_eos_id != -1);
+            output.push_back(vocab.special_eos_id);
+            return true;
+        }
+        return false;
+    }
+
+    void check_double_bos_eos(const std::vector & output) const {
+        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a EOS token to the prompt as specified by the model but the prompt "
+                "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector & output) {
+        int final_prev_index = -1;
+
+        const auto word_collection = unicode_regex_split(text, regex_exprs);
 
         symbols_final.clear();
 
@@ -12946,7 +15659,7 @@ struct llm_tokenizer_bpe {
             int index = 0;
             size_t offset = 0;
 
-            if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                 offset = word.size();
             }
@@ -13027,10 +15740,9 @@ struct llm_tokenizer_bpe {
                     for (auto j = str.begin(); j != str.end(); ++j) {
                         std::string byte_str(1, *j);
                         auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte == vocab.token_to_id.end()) {
-                            throw std::runtime_error("ERROR: byte not found in vocab");
+                        if (token_multibyte != vocab.token_to_id.end()) {
+                            output.push_back(token_multibyte->second);
                         }
-                        output.push_back((*token_multibyte).second);
                     }
                 } else {
                     output.push_back((*token).second);
@@ -13069,6 +15781,8 @@ struct llm_tokenizer_bpe {
 
     const llama_vocab & vocab;
 
+    std::vector regex_exprs;
+
     std::vector symbols;
     std::vector symbols_final;
 
@@ -13078,7 +15792,7 @@ struct llm_tokenizer_bpe {
 struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector & output) {
+    void tokenize(const std::string & text, std::vector & output) const {
         const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
@@ -13087,7 +15801,7 @@ struct llm_tokenizer_wpm {
         // bos token prepended already
 
         // find the longest tokens that form the words
-        for (const std::string &word : words) {
+        for (const std::string & word : words) {
             // skip empty words
             if (word.size() == 0) {
                 continue;
@@ -13104,7 +15818,7 @@ struct llm_tokenizer_wpm {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = n; j > i; j--) {
+                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
                     auto it = token_map.find(word1.substr(i, j - i));
                     if (it != token_map.end()) {
                         output.push_back(it->second);
@@ -13120,68 +15834,449 @@ struct llm_tokenizer_wpm {
                 }
             }
 
-            // we didn't find any matches for this word
-            if (current_tokens == output.size()) {
-                output.push_back(vocab.special_unk_id);
-            }
+            // we didn't find any matches for this word
+            if (current_tokens == output.size()) {
+                output.push_back(vocab.special_unk_id);
+            }
+        }
+    }
+
+    // TODO: reduce string copies by using cpts_offs array
+    std::vector preprocess(const std::string & text) const {
+        const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
+        std::vector words(1, "");
+
+        for (const uint32_t cpt : cpts_nfd) {
+            const auto flags = unicode_cpt_flags(cpt);
+
+            if (flags.is_whitespace) {
+                if (words.back().size()) {  // finish previous word if any
+                    words.emplace_back();
+                }
+                continue;
+            }
+
+            assert (!flags.is_separator);
+            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+                continue;
+            }
+
+            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
+            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+                if (words.back().size()) {  // finish previous word if any
+                    words.emplace_back();
+                }
+                words.back() = s;       // single char word
+                words.emplace_back();   // start a new word
+            } else {
+                words.back() += s;  // append char to word
+            }
+        }
+
+        if (!words.back().size()) {
+            words.pop_back();
+        }
+
+        return words;
+    }
+
+    static bool is_chinese_char(uint32_t cpt) {
+        return
+            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
+            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
+            (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
+            (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
+            (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
+            (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
+            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
+            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
+            //(cpt >= 0x3000  && cpt <= 0x303F)  ||
+            //(cpt >= 0xFF00  && cpt <= 0xFFEF);
+    }
+
+    const llama_vocab & vocab;
+};
+
+struct naive_trie {
+    naive_trie() : has_value(false), value(0) {
+    }
+    void insert(const char * key, size_t len, int32_t value = 0) {
+        if (len == 0) {
+            this->has_value = true;
+            this->value = value;
+            return;
+        }
+        char c = key[0];
+        auto res = children.find(c);
+        if (res != children.end()) {
+            res->second.insert(key + 1, len - 1, value);
+        } else {
+            auto res = children.insert(std::make_pair(c, naive_trie()));
+            res.first->second.insert(key + 1, len - 1, value);
+        }
+    }
+    std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+        if (len == 0 || offset == len) {
+            return std::make_pair(key, offset);
+        }
+        char c = key[offset];
+        auto res = children.find(c);
+        if (res != children.end()) {
+            return res->second.get_longest_prefix(key, len, offset + 1);
+        } else {
+            return std::make_pair(key, offset);
+        }
+    }
+    struct naive_trie * traverse(const char c) {
+        auto res = children.find(c);
+        if (res != children.end()) {
+            return &res->second;
+        } else {
+            return NULL;
+        }
+    }
+    std::map children;
+    bool has_value;
+    llama_token value;
+};
+
+struct llm_tokenizer_ugm {
+    llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+        if (vocab.precompiled_charsmap.size() > 0) {
+            size_t charsmap_offset = 0;
+
+            // First four bytes of precompiled_charsmap contains length of binary
+            // blob containing XOR-compressed compact double array (XCDA) entries
+            uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
+            charsmap_offset += sizeof(xcda_blob_size);
+            if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
+                throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
+            }
+
+            // Next xcda_blob_size bytes contain entries of XOR-compressed compact
+            // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
+            xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
+            xcda_array_size = xcda_blob_size / sizeof(uint32_t);
+            charsmap_offset += xcda_blob_size;
+
+            // Remaining bytes of precompiled charsmap contain null-terminated
+            // replacement strings for prefixes matched by the XCDA.
+            prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
+            prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
+        }
+
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto &token_data = vocab.id_to_token[id];
+
+            if (llama_is_normal_token(vocab, id)) {
+                min_score = std::min(min_score, token_data.score);
+                max_score = std::max(max_score, token_data.score);
+            }
+
+            if (llama_is_normal_token(vocab, id) ||
+                llama_is_user_defined_token(vocab, id) ||
+                llama_is_unused_token(vocab, id)) {
+                token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
+            }
+
+            if (llama_is_user_defined_token(vocab, id)) {
+                user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
+            }
+        }
+
+        unknown_token_score = min_score - unknown_token_score_penalty;
+    }
+
+    /* This implementation is based on SentencePiece optimized Viterbi algorithm for
+     * unigram language models. The general idea is to:
+     * - move along the input sequence in steps of one UTF code point,
+     * - at each step find all possible tokenizations of the prefix by
+     *   traversing the tokens trie,
+     * - for each tokenization store the best one so far (by higher score)
+     * - use the position in sequence after given token as an index to store
+     *   results
+     * - if there was no valid tokenization of the current UTF code point
+     *   then use unknown token with additional score penalty
+     * After processing the whole sequence we backtrack from the end to get
+     * the best tokenization.
+    */
+    void tokenize(const std::string & text, std::vector & output) {
+        // normalize the input first
+        std::string normalized;
+        normalize(text, &normalized);
+        size_t input_len = normalized.size();
+        if (input_len == 0) {
+            return;
+        }
+
+        // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
+        std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
+        // at the beginning tokenization score is zero
+        tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
+
+        for (size_t input_offset = 0; input_offset < input_len;) {
+            size_t prefix_offset = input_offset;
+            // calculate how many code units are in the currently processed UTF code point
+            size_t n_utf8_code_units = std::min(utf8_len(normalized[input_offset]), input_len - input_offset);
+
+            // traverse the token matcher trie to find a matching token
+            bool single_codepoint_token_found = false;
+            const struct best_tokenization & current_best = tokenization_results[input_offset];
+            struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]);
+
+            while (prefix_offset <= input_len && node != NULL) {
+                // check if we found valid token in prefix
+                if (node->has_value) {
+                    // check if it corresponds to the whole UTF code point
+                    if (prefix_offset - input_offset == n_utf8_code_units) {
+                        single_codepoint_token_found = true;
+                    }
+                    llama_token token_id = node->value;
+                    const auto & token_data = vocab.id_to_token[token_id];
+
+                    // we set the user-defined token scores to 0 to make them more likely to be selected
+                    // (normal token scores are log probabilities, so they are negative)
+                    // score type is double here to make tokenization results exactly
+                    // the same as in the HF tokenizer using SentencePiece
+                    const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
+                    const double challenger_score = current_best.score_sum + token_score;
+                    struct best_tokenization & current_champ = tokenization_results[prefix_offset];
+                    if (challenger_score > current_champ.score_sum) {
+                        struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
+                        current_champ = challenger;
+                    }
+                }
+                node = node->traverse(normalized[prefix_offset++]);
+            }
+
+            // if we didn't find a valid token corresponding to the whole UTF code point
+            // then use unknown token as the tokenization of this UTF code point
+            if (!single_codepoint_token_found) {
+                const double challenger_score = current_best.score_sum + unknown_token_score;
+                prefix_offset = input_offset + n_utf8_code_units;
+                struct best_tokenization & current_champ = tokenization_results[prefix_offset];
+                if (challenger_score > current_champ.score_sum) {
+                    struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
+                    current_champ = challenger;
+                }
+            }
+
+            // move to the next UTF code point
+            input_offset += n_utf8_code_units;
+        }
+
+        // now backtrack from the end to gather token ids of the best tokenization
+        // merge sequences of consecutive unknown tokens into single unknown tokens
+        bool is_prev_unknown = false;
+        for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
+            bool is_unknown = tokenization.token_id == vocab.special_unk_id;
+            if (!(is_prev_unknown && is_unknown)) {
+                output.push_back(tokenization.token_id);
+            }
+            if (tokenization.input_offset == 0) {
+                break;
+            }
+            is_prev_unknown = is_unknown;
+        }
+
+        // reverse the output since we added tokens starting from the end of the input
+        std::reverse(output.begin(), output.end());
+    }
+
+private:
+    const llama_vocab & vocab;
+
+    // helper structure for returning normalization results
+    struct normalization_result {
+        const char * normalized;
+        size_t normalized_len;
+        size_t consumed_input;
+    };
+
+    void normalize(const std::string& input, std::string * normalized) {
+        normalized->clear();
+        normalized->reserve(input.size() * 3);
+
+        const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+
+        bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
+        bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
+        bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
+
+        bool is_space_prepended = false;
+        bool processing_non_ws = false;
+
+        size_t input_len = input.size();
+
+        for (size_t input_offset = 0; input_offset < input_len; ) {
+            auto norm_res = normalize_prefix(input, input_offset);
+            for (size_t i = 0; i < norm_res.normalized_len; i++) {
+                char c = norm_res.normalized[i];
+                if (c != ' ') {
+                    if (!processing_non_ws) {
+                        processing_non_ws = true;
+                        if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
+                            normalized->append(space);
+                            is_space_prepended = true;
+                        }
+                    }
+                    normalized->push_back(c);
+                } else {
+                    if (processing_non_ws) {
+                        processing_non_ws = false;
+                    }
+                    if (!shall_merge_spaces) {
+                        normalized->append(space);
+                    }
+                }
+            }
+
+            input_offset += norm_res.consumed_input;
+        }
+
+        if (shall_append_space) {
+            normalized->append(space);
+        }
+    }
+
+    /*
+     * This structure is a view wrapper for XOR-compressed double array (XCDA)
+     * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
+     * Eeach bit-packed entry contains:
+     * - BASE array value in bits 10-30
+     * - LCHECK array value in bits 0-7
+     * - LEAF array value in bit 9
+     * Entries containing indexes of replacement sequences have set bit 31
+     */
+    struct xcda_array_view {
+    public:
+        xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) {
+        }
+        uint32_t get_base(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6);
+        }
+        uint32_t get_lcheck(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return packed_node & ((1U << 31) | 0xff);
+        }
+        bool get_leaf(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return (packed_node >> 8) & 1;
+        }
+        uint32_t get_value(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return packed_node & ((1U << 31) - 1);
+        }
+    private:
+        uint32_t get_node(size_t index) {
+            if (index > xcda_array_size) {
+                throw std::runtime_error("Index out of array bounds in XCDA array!");
+            }
+            return xcda_array[index];
+        }
+        const uint32_t * xcda_array;
+        size_t xcda_array_size;
+    };
+
+    struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
+        if (input_offset == input.size()) {
+            return { &input[input_offset], 0, 0 };
         }
-    }
 
-    std::vector preprocess(const std::string & text) {
-        const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
-        std::vector words(1, "");
+        // if input prefix matches some user-defined token return this token as normalization result
+        auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+        if (user_defined_token_match.second > 0) {
+            return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
+        }
 
-        for (const char32_t cpt : cpts_nfd) {
-            const auto flags = unicode_cpt_flags(cpt);
+        size_t longest_prefix_length = 0;
+        size_t longest_prefix_offset = 0;
 
-            if (flags.is_whitespace) {
-                if (words.back().size()) {  // finish previous word if any
-                    words.emplace_back();
+        if (xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+
+            // Find the longest normalized sequence matching the input prefix by walking
+            // the XOR-compressed compact double array (XCDA) starting from the root node
+            // We find the index of the next node by calculating BASE[s] ^ c where s is
+            // the index of the previous node and c is a numerical character value
+            uint32_t node_index = 0;
+            // get BASE of the root node
+            node_index = xcda_view.get_base(node_index);
+            for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
+                unsigned char c = input[prefix_offset];
+                if (c == 0) {
+                    break;
+                }
+                node_index ^= c;
+                // if value of LCHECK is not c it means that this is not a child of
+                // the previous node, so we stop matching
+                if (xcda_view.get_lcheck(node_index) != c) {
+                    break;
+                }
+                bool is_leaf = xcda_view.get_leaf(node_index);
+                // get BASE of the current node
+                node_index ^= xcda_view.get_base(node_index);
+                // if LEAF of the current node is true, it means that its BASE points to the node
+                // containing index of replacement sequence for currently matched input prefix
+                if (is_leaf)
+                {
+                    longest_prefix_length = prefix_offset - input_offset + 1;
+                    // get index of replacement sequence for currently matched input prefix
+                    longest_prefix_offset = xcda_view.get_value(node_index);
                 }
-                continue;
             }
+        }
 
-            assert (!flags.is_separator);
-            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
-                continue;
+        if (longest_prefix_length > 0) {
+            // we have a match, so return the replacement sequence
+            if (longest_prefix_offset >= prefix_replacements_size) {
+                throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-
-            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
-            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
-                if (words.back().size()) {  // finish previous word if any
-                    words.emplace_back();
-                }
-                words.back() = s;       // single char word
-                words.emplace_back();   // start a new word
-            } else {
-                words.back() += s;  // append char to word
+            const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+            return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
+        } else {
+            // check if the input prefix contains a valid sequence of UTF-8 code units
+            try {
+                // if yes, return this sequence unmodified
+                size_t prefix_offset = input_offset;
+                unicode_cpt_from_utf8(input, prefix_offset);
+                return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+            } catch (std::invalid_argument & /*ex*/) {
+                // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+                return { "\xEF\xBF\xBD", 3, 1 };
             }
         }
+    }
 
-        if (!words.back().size()) {
-            words.pop_back();
-        }
+    // escaped space symbol - U+2581 (Lower One Eighth Block)
+    const std::string escaped_space = "\xE2\x96\x81";
 
-        return words;
-    }
+    const char * prefix_replacements = NULL;
+    size_t prefix_replacements_size = 0;
 
-    static bool is_chinese_char(uint32_t cpt) {
-        return
-            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
-            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
-            (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
-            (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
-            (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
-            (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
-            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
-            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
-            //(cpt >= 0x3000  && cpt <= 0x303F)  ||
-            //(cpt >= 0xFF00  && cpt <= 0xFFEF);
-    }
+    const uint32_t * xcda_array = NULL;
+    size_t xcda_array_size = 0;
 
-    const llama_vocab & vocab;
+    struct naive_trie user_defined_token_matcher;
+
+    // this structure stores the best tokenization so far at input_offset
+    struct best_tokenization {
+        llama_token token_id;
+        size_t input_offset;
+        float score_sum;
+    };
+
+    float min_score = FLT_MAX;
+    float max_score = -FLT_MAX;
+
+    float unknown_token_score_penalty = 10.0;
+    float unknown_token_score;
+
+    struct naive_trie token_matcher;
 };
 
+
 typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
     FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
     FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
@@ -13218,10 +16313,19 @@ struct fragment_buffer_variant {
 
 // #define PRETOKENIZERDEBUG
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) {
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
-        const auto & special_token = vocab.id_to_token[special_id].text;
+    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
+        const auto & data = vocab.id_to_token[special_id];
+        const auto & special_token = data.text;
+
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
 
         // for each text fragment
         std::forward_list::iterator it = buffer.begin();
@@ -13258,13 +16362,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     if (match > raw_text_base_offset) {
                         // left
                         const int64_t left_reminder_offset = raw_text_base_offset + 0;
-                        const int64_t left_reminder_length = match - raw_text_base_offset;
-                        buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
 #endif
-                        it++;
                     }
 
                     // special token
@@ -13273,16 +16386,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
                     // right
                     if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        const int64_t right_reminder_offset = match + special_token.length();
-                        const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-                        buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                        int64_t right_reminder_offset = match + special_token.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
 #endif
 
-                        it++;
-
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
@@ -13317,7 +16439,7 @@ static std::vector llama_tokenize_internal(const llama_vocab &
 
     if (!raw_text.empty()) {
         fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
+        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
     }
 
     switch (vocab.type) {
@@ -13328,11 +16450,9 @@ static std::vector llama_tokenize_internal(const llama_vocab &
                 // tokenizer.encode('', add_special_tokens=True)  returns [1]
                 // tokenizer.encode('', add_special_tokens=False) returns []
 
-                static const bool rtrim = true;  //TODO: as param
-                bool is_prev_special = false;
-                bool special_token_rtrim = false;
+                bool is_prev_special = true;  // prefix with space if first token
 
-                if (add_special && vocab.special_add_bos != 0) {
+                if (add_special && vocab.tokenizer_add_bos) {
                     LM_GGML_ASSERT(vocab.special_bos_id != -1);
                     output.push_back(vocab.special_bos_id);
                     is_prev_special = true;
@@ -13340,29 +16460,11 @@ static std::vector llama_tokenize_internal(const llama_vocab &
 
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        // without adding this leading whitespace, we do not get the same results as the original tokenizer
-
-                        // TODO: It's likely possible to get rid of this string copy entirely
-                        //  by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
-                        //  and passing 'add space prefix' as bool argument
-                        //
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
-                        if (special_token_rtrim) {
-                            size_t num_whitespaces = 0;
-                            while (isspace(raw_text[num_whitespaces])) {
-                                num_whitespaces++;
-                            }
-                            if (num_whitespaces == raw_text.size()) {
-                                continue; // skip if all whitespaces
-                            }
-                            raw_text = raw_text.substr(num_whitespaces);
-                        }
-
-                        if (vocab.add_space_prefix) {
-                            if (!output.size() || is_prev_special) {  // prefix with space if first token
-                                raw_text = " " + raw_text;
-                            }
+                        // prefix with space if previous is special
+                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
+                            raw_text = " " + raw_text;
                         }
 
 #ifdef PRETOKENIZERDEBUG
@@ -13371,36 +16473,32 @@ static std::vector llama_tokenize_internal(const llama_vocab &
                         llm_tokenizer_spm tokenizer(vocab);
                         llama_escape_whitespace(raw_text);
                         tokenizer.tokenize(raw_text, output);
+                        is_prev_special = false;
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                         is_prev_special = true;
-                        // phi-3 special tokens without rtrim, works fine for llama-spm too
-                        special_token_rtrim = rtrim
-                            && fragment.token != vocab.special_bos_id
-                            && fragment.token != vocab.special_unk_id
-                            && fragment.token != vocab.special_eos_id;
                     }
                 }
 
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
                     LLAMA_LOG_WARN(
                         "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                         "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                         "Are you sure this is what you want?\n", __FUNCTION__);
                 }
 
-                if (add_special && vocab.special_add_eos == 1) {
+                if (add_special && vocab.tokenizer_add_eos) {
                     LM_GGML_ASSERT(vocab.special_eos_id != -1);
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                if (add_special && vocab.special_add_bos != 0) {
-                    LM_GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                }
+                llm_tokenizer_bpe tokenizer(vocab);
 
+                if (add_special) {
+                    tokenizer.append_bos(output);
+                }
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -13408,23 +16506,15 @@ static std::vector llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_bpe tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+                        tokenizer.append(fragment.token, output);
                     }
                 }
 
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.special_add_eos == 1) {
-                    LM_GGML_ASSERT(vocab.special_add_eos != -1);
-                    output.push_back(vocab.special_eos_id);
+                if (add_special) {
+                    tokenizer.append_eos(output);
+                    tokenizer.check_double_bos_eos(output);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
@@ -13434,6 +16524,8 @@ static std::vector llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_cls_id);
                 }
 
+                llm_tokenizer_wpm tokenizer(vocab);
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -13441,7 +16533,6 @@ static std::vector llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_wpm tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
@@ -13453,6 +16544,39 @@ static std::vector llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_sep_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                llm_tokenizer_ugm tokenizer(vocab);
+
+                if (add_special && vocab.tokenizer_add_bos != 0) {
+                    LM_GGML_ASSERT(vocab.special_bos_id != -1);
+                    output.push_back(vocab.special_bos_id);
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && vocab.tokenizer_add_eos == 1) {
+                    LM_GGML_ASSERT(vocab.special_eos_id != -1);
+                    output.push_back(vocab.special_eos_id);
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             LM_GGML_ASSERT(false);
     }
@@ -13541,7 +16665,7 @@ static std::pair llama_grammar_match_char(
         const uint32_t                chr) {
 
     bool found            = false;
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
 
     LM_GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
 
@@ -13550,6 +16674,10 @@ static std::pair llama_grammar_match_char(
             // inclusive range, e.g. [a-z]
             found = found || (pos->value <= chr && chr <= pos[1].value);
             pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            found = true;
+            pos += 1;
         } else {
             // exact char match, e.g. [a] or "a"
             found = found || pos->value == chr;
@@ -13567,7 +16695,7 @@ static bool llama_grammar_match_partial_char(
         const llama_grammar_element * pos,
         const llama_partial_utf8      partial_utf8) {
 
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
     LM_GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
 
     uint32_t partial_value = partial_utf8.value;
@@ -13597,6 +16725,9 @@ static bool llama_grammar_match_partial_char(
                 return is_positive_char;
             }
             pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            return true;
         } else {
             // exact char match, e.g. [a] or "a"
             if (low <= pos->value && pos->value <= high) {
@@ -13657,6 +16788,7 @@ static void llama_grammar_advance_stack(
         }
         case LLAMA_GRETYPE_CHAR:
         case LLAMA_GRETYPE_CHAR_NOT:
+        case LLAMA_GRETYPE_CHAR_ANY:
             if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                 // only add the stack if it's not a duplicate of one we already have
                 new_stacks.emplace_back(stack);
@@ -13859,7 +16991,8 @@ struct llama_grammar * llama_grammar_init(
             continue;
         }
         if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
-            throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
         }
     }
 
@@ -14379,7 +17512,7 @@ void llama_sample_repetition_penalties(
 
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     LM_GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+    int64_t t_start_sample_us = lm_ggml_time_us();
 
     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14391,12 +17524,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     std::vector, llama_partial_utf8>> candidates_decoded;
     candidates_decoded.reserve(candidates->size);
-    std::vector                              candidates_grammar;
+
+    std::vector candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
 
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
@@ -14596,7 +17730,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         LM_GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
     const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
@@ -14612,260 +17746,6 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     ctx->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
 }
 
-//
-// Beam search
-//
-
-struct llama_beam {
-    std::vector tokens;
-    float p;  // Cumulative beam probability (renormalized relative to all beams)
-    bool eob; // Initialize end-of-beam to false. Callback sets this to true.
-    // Sort beams by probability. In case of ties, prefer beams at eob.
-    bool operator<(const llama_beam & rhs) const {
-        return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
-    }
-    // Shift off first n tokens and discard them.
-    void shift_tokens(const size_t n) {
-        if (n) {
-            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
-            tokens.resize(tokens.size() - n);
-        }
-    }
-    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
-};
-
-// A struct for calculating logit-related info.
-struct llama_logit_info {
-    const float * const logits;
-    const int n_vocab;
-    const float max_l;
-    const float normalizer;
-    struct sum_exp {
-        float max_l;
-        float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
-    };
-    llama_logit_info(llama_context * ctx)
-      : logits(llama_get_logits(ctx))
-      , n_vocab(llama_n_vocab(llama_get_model(ctx)))
-      , max_l(*std::max_element(logits, logits + n_vocab))
-      , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
-      { }
-    llama_token_data get_token_data(const llama_token token_id) const {
-        constexpr auto p = std::numeric_limits::quiet_NaN();  // never used
-        return {token_id, logits[token_id], p};
-    }
-    // Return top k token_data by logit.
-    std::vector top_k(size_t k) {
-        std::vector min_heap;  // min-heap by logit
-        const llama_token k_min = std::min(static_cast(k), n_vocab);
-        min_heap.reserve(k_min);
-        for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
-            min_heap.push_back(get_token_data(token_id));
-        }
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; };
-        std::make_heap(min_heap.begin(), min_heap.end(), comp);
-        for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
-            if (min_heap.front().logit < logits[token_id]) {
-                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
-                min_heap.back().id = token_id;
-                min_heap.back().logit = logits[token_id];
-                std::push_heap(min_heap.begin(), min_heap.end(), comp);
-            }
-        }
-        return min_heap;
-    }
-    float probability_from_logit(float logit) const {
-        return normalizer * std::exp(logit - max_l);
-    }
-};
-
-struct llama_beam_search_data {
-    llama_context * ctx;
-    size_t n_beams;
-    int n_past;
-    int n_predict;
-    std::vector beams;
-    std::vector next_beams;
-
-    // Re-calculated on each loop iteration
-    size_t common_prefix_length;
-
-    // Used to communicate to/from callback on beams state.
-    std::vector beam_views;
-
-    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict)
-      : ctx(ctx)
-      , n_beams(n_beams)
-      , n_past(n_past)
-      , n_predict(n_predict)
-      , beam_views(n_beams) {
-        beams.reserve(n_beams);
-        next_beams.reserve(n_beams);
-    }
-
-    // Collapse beams to a single beam given by index.
-    void collapse_beams(const size_t beam_idx) {
-        if (0u < beam_idx) {
-            std::swap(beams[0], beams[beam_idx]);
-        }
-        beams.resize(1);
-    }
-
-    // Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
-    // The repetitive patterns below reflect the 2 stages of heaps:
-    //  * Gather elements until the vector is full, then call std::make_heap() on it.
-    //  * If the heap is full and a new element is found that should be included, pop the
-    //    least element to the back(), replace it with the new, then push it into the heap.
-    void fill_next_beams_by_top_probabilities(llama_beam & beam) {
-        // Min-heaps use a greater-than comparator.
-        const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
-        if (beam.eob) {
-            // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
-            if (next_beams.size() < n_beams) {
-                next_beams.push_back(std::move(beam));
-                if (next_beams.size() == n_beams) {
-                    std::make_heap(next_beams.begin(), next_beams.end(), comp);
-                }
-            } else if (next_beams.front().p < beam.p) {
-                std::pop_heap(next_beams.begin(), next_beams.end(), comp);
-                next_beams.back() = std::move(beam);
-                std::push_heap(next_beams.begin(), next_beams.end(), comp);
-            }
-        } else {
-            // beam is not at end-of-sentence, so branch with next top_k tokens.
-            if (!beam.tokens.empty()) {
-                llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0));
-            }
-            llama_logit_info logit_info(ctx);
-            std::vector next_tokens = logit_info.top_k(n_beams);
-
-            // Clear the kv slot so that other beams may try different tokens at this position. The llama_decode()
-            // call in loop() will conclusively fill in the kv slot once the beams converge at this position.
-            llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-
-            size_t i=0;
-            if (next_beams.size() < n_beams) {
-                for (; next_beams.size() < n_beams ; ++i) {
-                    llama_beam next_beam = beam;
-                    next_beam.tokens.push_back(next_tokens[i].id);
-                    next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
-                    next_beams.push_back(std::move(next_beam));
-                }
-                std::make_heap(next_beams.begin(), next_beams.end(), comp);
-            } else {
-                for (; next_beams.front().p == 0.0f ; ++i) {
-                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
-                    next_beams.back() = beam;
-                    next_beams.back().tokens.push_back(next_tokens[i].id);
-                    next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
-                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
-                }
-            }
-            for (; i < n_beams ; ++i) {
-                const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
-                if (next_beams.front().p < next_p) {
-                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
-                    next_beams.back() = beam;
-                    next_beams.back().tokens.push_back(next_tokens[i].id);
-                    next_beams.back().p = next_p;
-                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
-                }
-            }
-        }
-    }
-
-    // Find common_prefix_length based on beams.
-    // Requires beams is not empty.
-    size_t find_common_prefix_length() {
-        size_t common_prefix_length = beams[0].tokens.size();
-        for (size_t i = 1 ; i < beams.size() ; ++i) {
-            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
-            for (size_t j = 0 ; j < common_prefix_length ; ++j) {
-                if (beams[0].tokens[j] != beams[i].tokens[j]) {
-                    common_prefix_length = j;
-                    break;
-                }
-            }
-        }
-        return common_prefix_length;
-    }
-
-    // Construct beams_state to send back to caller via the callback function.
-    // Side effect: set common_prefix_length = find_common_prefix_length();
-    llama_beams_state get_beams_state(const bool last_call) {
-        for (size_t i = 0 ; i < beams.size() ; ++i) {
-            beam_views[i] = beams[i].view();
-        }
-        common_prefix_length = find_common_prefix_length();
-        return {beam_views.data(), beams.size(), common_prefix_length, last_call};
-    }
-
-    // Loop:
-    //  * while i < n_predict, AND
-    //  * any of the beams have not yet reached end-of-beam (eob), AND
-    //  * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
-    //    (since all other beam probabilities can only decrease)
-    void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
-        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eob.
-        const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
-        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) &&
-                       !beams[top_beam_index()].eob ; ++i) {
-            callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
-            update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed.
-            if (common_prefix_length) {
-                llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0));
-                n_past += common_prefix_length;
-            }
-            // Zero-out next_beam probabilities to place them last in following min-heap.
-            std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; });
-            for (llama_beam & beam : beams) {
-                beam.shift_tokens(common_prefix_length);
-                fill_next_beams_by_top_probabilities(beam);
-            }
-            // next_beams become the beams of next/final iteration. Swap them to re-use memory.
-            beams.swap(next_beams);
-            renormalize_beam_probabilities(beams);
-        }
-        collapse_beams(top_beam_index());
-        callback(callback_data, get_beams_state(true));
-    }
-
-    // As beams grow, the cumulative probabilities decrease.
-    // Renormalize them to avoid floating point underflow.
-    static void renormalize_beam_probabilities(std::vector & beams) {
-        const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; };
-        const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
-        std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; });
-    }
-
-    // Assumes beams is non-empty.  Uses llama_beam::operator<() for ordering.
-    size_t top_beam_index() {
-        return std::max_element(beams.begin(), beams.end()) - beams.begin();
-    }
-
-    // Copy (p,eob) for each beam which may have been changed by the callback.
-    void update_beams_from_beam_views() {
-        for (size_t i = 0 ; i < beams.size() ; ++i) {
-            beams[i].p = beam_views[i].p;
-            beams[i].eob = beam_views[i].eob;
-        }
-    }
-};
-
-void llama_beam_search(llama_context * ctx,
-                       llama_beam_search_callback_fn_t callback, void * callback_data,
-                       size_t n_beams, int n_past, int n_predict) {
-    assert(ctx);
-    const int64_t t_start_sample_us = lm_ggml_time_us();
-
-    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict);
-
-    beam_search_data.loop(callback, callback_data);
-
-    ctx->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    ctx->n_sample++;
-}
-
 //
 // quantization
 //
@@ -14977,8 +17857,8 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_
     const llm_arch arch = qs.model.arch;
     const auto       tn = LLM_TN(arch);
 
-    auto use_more_bits = [](int i_layer, int num_layers) -> bool {
-        return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2;
+    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
+        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
     };
     const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
     auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
@@ -15030,6 +17910,10 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_
             else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                 new_type = LM_GGML_TYPE_IQ3_S;
             }
+            else if (new_type == LM_GGML_TYPE_Q4_0_4_4 || new_type == LM_GGML_TYPE_Q4_0_4_8 ||
+                     new_type == LM_GGML_TYPE_Q4_0_8_8) {
+                new_type = LM_GGML_TYPE_Q4_0;
+            }
         }
     } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
@@ -15213,10 +18097,10 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = LM_GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
-    if (new_type == LM_GGML_TYPE_Q2_K || new_type == LM_GGML_TYPE_Q3_K || new_type == LM_GGML_TYPE_Q4_K ||
-        new_type == LM_GGML_TYPE_Q5_K || new_type == LM_GGML_TYPE_Q6_K || new_type == LM_GGML_TYPE_IQ4_XS ||
-        new_type == LM_GGML_TYPE_IQ2_XS || new_type == LM_GGML_TYPE_IQ2_XXS || new_type == LM_GGML_TYPE_IQ2_S ||
-        new_type == LM_GGML_TYPE_IQ3_XXS || new_type == LM_GGML_TYPE_IQ1_S || new_type == LM_GGML_TYPE_IQ3_S ||
+    if (new_type == LM_GGML_TYPE_Q2_K    || new_type == LM_GGML_TYPE_Q3_K    || new_type == LM_GGML_TYPE_Q4_K   ||
+        new_type == LM_GGML_TYPE_Q5_K    || new_type == LM_GGML_TYPE_Q6_K    || new_type == LM_GGML_TYPE_IQ4_XS ||
+        new_type == LM_GGML_TYPE_IQ2_XS  || new_type == LM_GGML_TYPE_IQ2_XXS || new_type == LM_GGML_TYPE_IQ2_S  ||
+        new_type == LM_GGML_TYPE_IQ3_XXS || new_type == LM_GGML_TYPE_IQ1_S   || new_type == LM_GGML_TYPE_IQ3_S  ||
         new_type == LM_GGML_TYPE_IQ1_M) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
@@ -15342,6 +18226,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = LM_GGML_TYPE_IQ4_XS;  break;
         case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = LM_GGML_TYPE_IQ3_S;   break;
         case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = LM_GGML_TYPE_IQ3_S;   break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = LM_GGML_TYPE_Q4_0_4_4; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = LM_GGML_TYPE_Q4_0_4_8; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = LM_GGML_TYPE_Q4_0_8_8; break;
 
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
@@ -15383,6 +18270,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         if (imatrix_data) {
             LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
             qs.has_imatrix = true;
+            // check imatrix for nans or infs
+            for (const auto & kv : *imatrix_data) {
+                for (float f : kv.second) {
+                    if (!std::isfinite(f)) {
+                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
+                    }
+                }
+            }
         }
     }
 
@@ -15434,10 +18329,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     // sanity checks
     //
-    //  - qs.n_attention_wv == 0                     for Mamba       models
-    //  - qs.n_attention_wv == model.hparams.n_layer for Transformer models
+    //  - qs.n_attention_wv == 0                         for Mamba           models
+    //  - qs.n_attention_wv == model.hparams.n_layer     for Transformer     models
+    //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
     //
-    LM_GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+    LM_GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -15562,6 +18458,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= name.find("ssm_x.weight")      == std::string::npos;
         quantize &= name.find("ssm_dt.weight")     == std::string::npos;
 
+        // do not quantize relative position bias (T5)
+        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
+
         enum lm_ggml_type new_type;
         void * new_data;
         size_t new_size;
@@ -15640,6 +18539,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 f32_data = (float *) f32_conv_buf.data();
             }
 
+            int chunk_size_multiplier = 1;
+            if (new_type == LM_GGML_TYPE_Q4_0_4_4 || new_type == LM_GGML_TYPE_Q4_0_4_8 || new_type == LM_GGML_TYPE_Q4_0_8_8) {
+                if ((new_type == LM_GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = LM_GGML_TYPE_Q4_0;
+                else if (tensor->ne[1] % 4 != 0) new_type = LM_GGML_TYPE_Q4_0;
+                if (new_type == LM_GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8;
+                else if (new_type == LM_GGML_TYPE_Q4_0_4_4 || new_type == LM_GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4;
+            }
+
             LLAMA_LOG_INFO("converting to %s .. ", lm_ggml_type_name(new_type));
             fflush(stdout);
 
@@ -15652,7 +18559,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             const int64_t nrows = tensor->ne[1];
 
             static const int64_t min_chunk_size = 32 * 512;
-            const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
+            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) *
+                                       chunk_size_multiplier;
 
             const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
             const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
@@ -15666,310 +18574,238 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
 
                 new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
-            }
-            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", lm_ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
-        }
-        total_size_org += lm_ggml_nbytes(tensor);
-        total_size_new += new_size;
-
-        // update the gguf meta data as we go
-        lm_gguf_set_tensor_type(ctx_outs[cur_split], name.c_str(), new_type);
-        lm_gguf_set_tensor_data(ctx_outs[cur_split], name.c_str(), new_data, new_size);
-
-        // write tensor data + padding
-        fout.write((const char *) new_data, new_size);
-        zeros(fout, LM_GGML_PAD(new_size, align) - new_size);
-    }
-    close_ofstream();
-    for (auto & c:ctx_outs) {
-        lm_gguf_free(c);
-    }
-
-    LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
-    LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
-
-    if (qs.n_fallback > 0) {
-        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
-                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
-    }
-}
-
-static int llama_apply_lora_from_file_internal(
-    const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
-) {
-    LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
-
-    const int64_t t_start_lora_us = lm_ggml_time_us();
-
-    llama_file fin(path_lora, "rb");
-
-    // verify magic and version
-    {
-        uint32_t magic = fin.read_u32();
-        if (magic != LLAMA_FILE_MAGIC_GGLA) {
-            LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
-            return 1;
-        }
-
-        uint32_t format_version = fin.read_u32();
-        if (format_version != 1) {
-            LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
-            return 1;
-        }
-    }
-
-    int32_t lora_r = fin.read_u32();
-    int32_t lora_alpha = fin.read_u32();
-    float scaling = scale * (float)lora_alpha / (float)lora_r;
-
-    LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
-
-    // load base model
-    std::unique_ptr ml;
-    if (path_base_model) {
-        LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
-        ml->init_mappings(/*prefetch*/ false); // no prefetching
-    }
-
-    struct tensor_meta {
-        std::string name;
-        lm_ggml_type type;
-        int32_t ne[2];
-        size_t offset;
-    };
-    std::map tensor_meta_map;
-
-    // load all tensor meta
-    while (true) {
-        if (fin.tell() == fin.size) {
-            // eof
-            break;
-        }
-
-        int32_t n_dims;
-        int32_t name_len;
-        int32_t ftype;
-
-        fin.read_raw(&n_dims, sizeof(n_dims));
-        fin.read_raw(&name_len, sizeof(name_len));
-        fin.read_raw(&ftype, sizeof(ftype));
-
-        if (n_dims != 1 && n_dims != 2) {
-            LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
-            return 1;
-        }
-
-        int32_t ne[2] = { 1, 1 };
-        for (int i = 0; i < n_dims; ++i) {
-            fin.read_raw(&ne[i], sizeof(ne[i]));
-        }
-
-        std::string name;
-        {
-            LM_GGML_ASSERT(name_len < LM_GGML_MAX_NAME);
-            char buf[LM_GGML_MAX_NAME];
-            fin.read_raw(buf, name_len);
-            name = std::string(buf, name_len);
-        }
-
-        // check for lora suffix
-        std::string lora_suffix;
-        if (name.length() > 6) {
-            lora_suffix = name.substr(name.length() - 6);
-        }
-        if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
-            LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
-            return 1;
+            }
+            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", lm_ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
         }
+        total_size_org += lm_ggml_nbytes(tensor);
+        total_size_new += new_size;
 
-        // tensor type
-        lm_ggml_type wtype;
-        switch (ftype) {
-            case 0: wtype = LM_GGML_TYPE_F32;  break;
-            case 1: wtype = LM_GGML_TYPE_F16;  break;
-            default:
-                    {
-                        LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
-                                __func__, ftype);
-                        return 1;
-                    }
-        }
+        // update the gguf meta data as we go
+        lm_gguf_set_tensor_type(ctx_outs[cur_split], name.c_str(), new_type);
+        lm_gguf_set_tensor_data(ctx_outs[cur_split], name.c_str(), new_data, new_size);
 
-        // data offset
-        size_t offset = fin.tell();
-        offset = (offset + 31) & -32;
+        // write tensor data + padding
+        fout.write((const char *) new_data, new_size);
+        zeros(fout, LM_GGML_PAD(new_size, align) - new_size);
+    }
+    close_ofstream();
+    for (auto & c:ctx_outs) {
+        lm_gguf_free(c);
+    }
 
-        // skip tensor data
-        fin.seek(offset + lm_ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
+    LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+    LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
 
-        tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
+    if (qs.n_fallback > 0) {
+        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
+                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
     }
+}
 
-    bool warned = false;
-    int n_tensors = 0;
+static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
+    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
-    // apply
-    lm_ggml_backend_t backend_cpu = lm_ggml_backend_cpu_init();
-    if (backend_cpu == nullptr) {
-        LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
-        return 1;
+    lm_ggml_context * ctx = nullptr;
+    struct lm_gguf_init_params meta_lm_gguf_params = {
+        /* .no_alloc = */ true,
+        /* .ctx      = */ &ctx,
+    };
+    struct lm_gguf_context * ctx_gguf = lm_gguf_init_from_file(path_lora, meta_lm_gguf_params);
+    if (!ctx_gguf) {
+        throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
     }
-    lm_ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
 
-    std::vector> read_buf;
-    for (const auto & it : model.tensors_by_name) {
-        const std::string & base_name = it.first;
-        lm_ggml_tensor * model_t = it.second;
+    // check metadata
+    {
+        auto get_kv_str = [&](const std::string & key) -> std::string {
+            int id = lm_gguf_find_key(ctx_gguf, key.c_str());
+            return id < 0 ? "" : std::string(lm_gguf_get_val_str(ctx_gguf, id));
+        };
+        auto get_kv_f32 = [&](const std::string & key) -> float {
+            int id = lm_gguf_find_key(ctx_gguf, key.c_str());
+            return id < 0 ? 0.0f : lm_gguf_get_val_f32(ctx_gguf, id);
+        };
+        LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
-        if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
-            tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
-            continue;
+        auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
+        if (general_type != "adapter") {
+            lm_gguf_free(ctx_gguf);
+            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
         }
 
-        tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
-        tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
-
-        lm_ggml_init_params lora_init_params = {
-            /* .mem_size   */ lm_ggml_tensor_overhead()*128 + lm_ggml_graph_overhead(),
-            /* .mem_buffer */ nullptr,
-            /* .no_alloc   */ true,
-        };
-        lm_ggml_context * lora_ctx = lm_ggml_init(lora_init_params);
-        if (lora_ctx == nullptr) {
-            LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
-            lm_ggml_backend_free(backend_cpu);
-            return 1;
+        auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
+        auto general_arch = llm_arch_from_string(general_arch_str);
+        if (general_arch != model->arch) {
+            lm_gguf_free(ctx_gguf);
+            throw std::runtime_error("model arch and LoRA arch mismatch");
         }
 
-        // create tensors
-        lm_ggml_tensor * loraA = lm_ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
-        lm_ggml_tensor * loraB = lm_ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
-        lm_ggml_set_name(loraA, metaA.name.c_str());
-        lm_ggml_set_name(loraB, metaB.name.c_str());
-
-        lm_ggml_tensor * base_t;
-        if (ml) {
-            if (!ml->get_tensor_meta(base_name.c_str())) {
-                LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
-                return 1;
-            }
-            base_t = lm_ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
-        } else {
-            base_t = lm_ggml_dup_tensor(lora_ctx, model_t);
+        auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
+        if (adapter_type != "lora") {
+            lm_gguf_free(ctx_gguf);
+            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
         }
-        lm_ggml_set_name(base_t, base_name.c_str());
 
-        // allocate in backend buffer
-        lm_ggml_backend_buffer_t lora_buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, lm_ggml_backend_cpu_buffer_type());
-        if (lora_buf == nullptr) {
-            LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
-            return 1;
-        }
+        adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
+    }
 
-        // load tensor data
-        auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, lm_ggml_tensor * tensor) {
-            read_buf.resize(lm_ggml_nbytes(tensor));
-            fin.seek(tensor_meta.offset, SEEK_SET);
-            fin.read_raw(read_buf.data(), lm_ggml_nbytes(tensor));
-            lm_ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
+    int n_tensors = lm_gguf_get_n_tensors(ctx_gguf);
+
+    // contexts for each buffer type
+    std::map ctx_map;
+    auto get_ctx_for_buft = [&](lm_ggml_backend_buffer_type_t buft) -> lm_ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            // add a new context
+            struct lm_ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors*lm_ggml_tensor_overhead(),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            lm_ggml_context * buft_ctx = lm_ggml_init(params);
+            ctx_map[buft] = buft_ctx;
+            return buft_ctx;
         };
-        load_tensor(metaA, loraA);
-        load_tensor(metaB, loraB);
+        return it->second;
+    };
 
-        // load base model tensor data
-        if (ml) {
-            ml->load_data_for(base_t);
+    // bundle lora_a and lora_b into pairs
+    std::map ab_map;
+    auto str_endswith = [](const std::string & str, const std::string & suffix) {
+        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+    };
+    for (lm_ggml_tensor * cur = lm_ggml_get_first_tensor(ctx); cur; cur = lm_ggml_get_next_tensor(ctx, cur)) {
+        std::string name(cur->name);
+        if (str_endswith(name, ".lora_a")) {
+            replace_all(name, ".lora_a", "");
+            if (ab_map.find(name) == ab_map.end()) {
+                ab_map[name] = llama_lora_weight(cur, nullptr);
+            } else {
+                ab_map[name].a = cur;
+            }
+        } else if (str_endswith(name, ".lora_b")) {
+            replace_all(name, ".lora_b", "");
+            if (ab_map.find(name) == ab_map.end()) {
+                ab_map[name] = llama_lora_weight(nullptr, cur);
+            } else {
+                ab_map[name].b = cur;
+            }
         } else {
-            lm_ggml_backend_tensor_copy(model_t, base_t);
+            lm_gguf_free(ctx_gguf);
+            lm_ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
         }
+    }
 
-        if (lm_ggml_is_quantized(base_t->type) && !warned) {
-            LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
-                            "use a f16 or f32 base model with --lora-base\n", __func__);
-            warned = true;
-        }
+    // add tensors
+    for (auto & it : ab_map) {
+        const std::string & name = it.first;
+        llama_lora_weight & w = it.second;
 
-        if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
-            LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
-                            " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
-            lm_ggml_free(lora_ctx);
-            lm_ggml_backend_buffer_free(lora_buf);
-            lm_ggml_backend_free(backend_cpu);
-            return 1;
+        if (!w.a || !w.b) {
+            lm_gguf_free(ctx_gguf);
+            lm_ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
         }
 
-        auto build_lora_graph = [&]() {
-            // w = w + BA*s
-            lm_ggml_tensor * BA = lm_ggml_mul_mat(lora_ctx, loraA, loraB);
-            lm_ggml_set_name(BA, "BA");
-
-            if (scaling != 1.0f) {
-                BA = lm_ggml_scale(lora_ctx, BA, scaling);
-                lm_ggml_set_name(BA, "BA_scaled");
-            }
-
-            lm_ggml_tensor * r;
-            r = lm_ggml_add_inplace(lora_ctx, base_t, BA);
-            lm_ggml_set_name(r, "r_add");
+        // device buft and device ctx
+        auto * model_tensor = llama_get_model_tensor(model, name.c_str());
+        if (!model_tensor) {
+            lm_gguf_free(ctx_gguf);
+            lm_ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
+        }
+        struct lm_ggml_context * dev_ctx = get_ctx_for_buft(lm_ggml_backend_buffer_get_type(model_tensor->buffer));
+        // validate tensor shape
+        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
+            lm_gguf_free(ctx_gguf);
+            lm_ggml_free(ctx);
+            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
+        }
+        if (w.a->ne[1] != w.b->ne[0]) {
+            lm_gguf_free(ctx_gguf);
+            lm_ggml_free(ctx);
+            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+        }
+        // save tensor to adapter
+        struct lm_ggml_tensor * tensor_a = lm_ggml_dup_tensor(dev_ctx, w.a);
+        struct lm_ggml_tensor * tensor_b = lm_ggml_dup_tensor(dev_ctx, w.b);
+        lm_ggml_set_name(tensor_a, w.a->name);
+        lm_ggml_set_name(tensor_b, w.b->name);
+        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
+    }
 
-            if (base_t->type != model_t->type) {
-                // convert the result to the model type
-                r = lm_ggml_cast(lora_ctx, r, model_t->type);
-                lm_ggml_set_name(r, "r_cast");
+    // allocate tensors / buffers and zero
+    {
+        adapter.ctxs.reserve(ctx_map.size());
+        adapter.bufs.reserve(ctx_map.size());
+        for (auto it : ctx_map) {
+            lm_ggml_backend_buffer_type_t buft = it.first;
+            lm_ggml_context * ctx_dev = it.second;
+            lm_ggml_backend_buffer_t buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
+            if (!buf) {
+                lm_gguf_free(ctx_gguf);
+                lm_ggml_free(ctx);
+                throw std::runtime_error("failed to allocate buffer for lora adapter\n");
             }
+            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf), lm_ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+            adapter.ctxs.push_back(ctx_dev);
+            adapter.bufs.push_back(buf);
+        }
+    }
 
-            return r;
+    // set tensor data
+    {
+        llama_file lm_gguf_file(path_lora, "rb");
+        std::vector read_buf;
+        auto set_tensor = [&](struct lm_ggml_tensor * orig, struct lm_ggml_tensor * dev) {
+            size_t offs = lm_gguf_get_data_offset(ctx_gguf) + lm_gguf_get_tensor_offset(ctx_gguf, lm_gguf_find_tensor(ctx_gguf, orig->name));
+            size_t size = lm_ggml_nbytes(orig);
+            read_buf.resize(size);
+            lm_gguf_file.seek(offs, SEEK_SET);
+            lm_gguf_file.read_raw(read_buf.data(), size);
+            lm_ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
         };
-
-        lm_ggml_cgraph * gf = lm_ggml_new_graph(lora_ctx);
-        lm_ggml_tensor * r = build_lora_graph();
-        lm_ggml_build_forward_expand(gf, r);
-
-        lm_ggml_backend_buffer_t graph_buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, lm_ggml_backend_cpu_buffer_type());
-        if (graph_buf == nullptr) {
-            LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
-            lm_ggml_free(lora_ctx);
-            lm_ggml_backend_buffer_free(lora_buf);
-            lm_ggml_backend_free(backend_cpu);
-            return 1;
+        for (auto & it : adapter.ab_map) {
+            auto orig = ab_map[it.first];
+            auto dev  = it.second;
+            set_tensor(orig.a, dev.a);
+            set_tensor(orig.b, dev.b);
         }
+    }
 
-        lm_ggml_backend_graph_compute(backend_cpu, gf);
-
-        lm_ggml_backend_tensor_set(model_t, r->data, 0, lm_ggml_nbytes(r));
-
-#if 0
-        // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
-        //lm_ggml_backend_sched_t sched = lm_ggml_backend_sched_new(backends.data(), backends.size(), LM_GGML_DEFAULT_GRAPH_SIZE);
-
-        // sched compute
-        lm_ggml_build_forward_expand(gf, build_graph());
-        lm_ggml_backend_sched_init_measure(sched, gf);
-
-        // create the graph again, since the previous one was destroyed by the measure
-        lm_ggml_graph_clear(gf);
-        lm_ggml_build_forward_expand(gf, build_graph());
-        lm_ggml_backend_sched_graph_compute(sched, gf);
-        lm_ggml_backend_sched_free(sched);
-#endif
+    LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 
-        lm_ggml_backend_buffer_free(lora_buf);
-        lm_ggml_backend_buffer_free(graph_buf);
-        lm_ggml_free(lora_ctx);
+    // free ctx for reading gguf
+    lm_gguf_free(ctx_gguf);
+    lm_ggml_free(ctx);
+}
 
-        n_tensors++;
-        if (n_tensors % 4 == 0) {
-            LLAMA_LOG_INFO(".");
-        }
+int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale) {
+    if (ctx->cparams.flash_attn) {
+        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
+        return -1;
     }
+    ctx->lora_adapters[adapter] = scale;
+    return 0;
+}
 
-    lm_ggml_backend_free(backend_cpu);
-
-    const int64_t t_lora_us = lm_ggml_time_us() - t_start_lora_us;
-    LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
+int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter) {
+    auto pos = ctx->lora_adapters.find(adapter);
+    if (pos != ctx->lora_adapters.end()) {
+        ctx->lora_adapters.erase(pos);
+        return 0;
+    }
+    return -1;
+}
 
-    return 0;
+void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
+    delete adapter;
 }
 
 //
@@ -16010,6 +18846,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch             =*/ LM_GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -16062,6 +18899,8 @@ size_t llama_max_devices(void) {
     return LM_GGML_SYCL_MAX_DEVICES;
 #elif defined(LM_GGML_USE_VULKAN)
     return LM_GGML_VK_MAX_DEVICES;
+#elif defined(LM_GGML_USE_CANN)
+    return LM_GGML_CANN_MAX_DEVICES;
 #else
     return 1;
 #endif
@@ -16076,7 +18915,7 @@ bool llama_supports_mlock(void) {
 }
 
 bool llama_supports_gpu_offload(void) {
-#if defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_CLBLAST) || defined(LM_GGML_USE_METAL) || defined(LM_GGML_USE_VULKAN) || \
+#if defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_METAL)   || defined(LM_GGML_USE_VULKAN) || \
     defined(LM_GGML_USE_SYCL) || defined(LM_GGML_USE_KOMPUTE) || defined(LM_GGML_USE_RPC)
     // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
     return true;
@@ -16133,7 +18972,7 @@ struct llama_model * llama_load_model_from_file(
             return true;
         };
     }
-    if (params.rpc_servers != nullptr) {
+    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
         // split the servers set them into model->rpc_servers
         std::string servers(params.rpc_servers);
         size_t pos = 0;
@@ -16187,6 +19026,22 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
+    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (params.type_v != LM_GGML_TYPE_F16 && !params.flash_attn) {
+        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
+        return nullptr;
+    }
+
     llama_context * ctx = new llama_context(*model);
 
     const auto & hparams = model->hparams;
@@ -16225,8 +19080,8 @@ struct llama_context * llama_new_context_with_model(
 
     cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
-                               hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
+    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                                                               hparams.n_ctx_train;
 
     cparams.cb_eval           = params.cb_eval;
@@ -16246,7 +19101,6 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -16256,6 +19110,12 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
@@ -16291,17 +19151,7 @@ struct llama_context * llama_new_context_with_model(
 
     if (!hparams.vocab_only) {
         // initialize backends
-#if defined(LM_GGML_USE_RPC)
-        for (auto & server : model->rpc_servers) {
-            lm_ggml_backend_t backend = lm_ggml_backend_rpc_init(server.c_str());
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-#elif defined(LM_GGML_USE_METAL)
+#if defined(LM_GGML_USE_METAL)
         if (model->n_gpu_layers > 0) {
             ctx->backend_metal = lm_ggml_backend_metal_init();
             if (ctx->backend_metal == nullptr) {
@@ -16340,7 +19190,7 @@ struct llama_context * llama_new_context_with_model(
             return nullptr;
         }
         if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
-            lm_ggml_backend_t backend = lm_ggml_backend_vk_init(0);
+            lm_ggml_backend_t backend = lm_ggml_backend_vk_init(model->main_gpu);
             if (backend == nullptr) {
                 LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
                 llama_free(ctx);
@@ -16363,8 +19213,7 @@ struct llama_context * llama_new_context_with_model(
         if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
             lm_ggml_backend_t backend = lm_ggml_backend_sycl_init(model->main_gpu);
             if (backend == nullptr) {
-                int main_gpu_id = lm_ggml_backend_sycl_get_device_id(model->main_gpu);
-                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, main_gpu_id, model->main_gpu);
+                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
                 llama_free(ctx);
                 return nullptr;
             }
@@ -16393,6 +19242,53 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(backend);
         }
+#elif defined(LM_GGML_USE_CANN)
+    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+    // TODO: lm_ggml_backend_cann is not support split tensor now, just leave code here.
+    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+        lm_ggml_backend_t backend = lm_ggml_backend_cann_init(model->main_gpu);
+        if (backend == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
+            llama_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.push_back(backend);
+    } else {
+        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+        // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
+        for (int32_t device = 0; device < lm_ggml_backend_cann_get_device_count(); ++device) {
+            lm_ggml_backend_t backend = lm_ggml_backend_cann_init(device);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+    }
+#endif
+
+#ifdef LM_GGML_USE_BLAS
+        ctx->backend_blas = lm_ggml_backend_blas_init();
+        if (ctx->backend_blas == nullptr) {
+            LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
+        } else {
+            ctx->backends.push_back(ctx->backend_blas);
+        }
+#endif
+
+#if defined(LM_GGML_USE_RPC)
+        if (model->n_gpu_layers > 0) {
+            for (const auto & endpoint : model->rpc_servers) {
+                lm_ggml_backend_t backend = lm_ggml_backend_rpc_init(endpoint.c_str());
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
 #endif
         ctx->backend_cpu = lm_ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
@@ -16545,6 +19441,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_BLOOM:
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_T5:
+        case LLM_ARCH_JAIS:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -16561,6 +19459,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMO:
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -16570,13 +19469,16 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
+        case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
             return LLAMA_ROPE_TYPE_NEOX;
 
@@ -16686,6 +19588,17 @@ struct lm_ggml_tensor * llama_get_model_tensor(struct llama_model * model, const
     return it->second;
 }
 
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5: return true;
+        default:          return false;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.dec_start_token_id;
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
@@ -16699,12 +19612,14 @@ uint32_t llama_model_quantize(
     }
 }
 
-int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
+struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
     try {
-        return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
+        struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
+        llama_lora_adapter_init_internal(model, path_lora, *adapter);
+        return adapter;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
-        return 1;
+        return nullptr;
     }
 }
 
@@ -17020,7 +19935,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     );
 
     // on session change it is very likely that the state size has changed - so we need to update this function
-    static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
+    static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
 
     return s_total;
 }
@@ -17159,8 +20074,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer      = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
         const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
@@ -17180,6 +20093,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
 
             std::vector tmp_buf;
             for (int il = 0; il < (int) n_layer; ++il) {
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
                 tmp_buf.resize(k_size);
@@ -17312,8 +20228,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer      = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -17345,6 +20259,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
             LM_GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
 
             for (int il = 0; il < (int) n_layer; ++il) {
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
                 lm_ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
@@ -17507,8 +20424,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     for (uint32_t i = 0; i < kv_self.size; ++i) {
         const auto & cell = kv_self.cells[i];
@@ -17519,6 +20434,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     }
 
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         // types of keys and values
         s_cell_data_size += sizeof(int32_t) * 2;
         // k_size_row and v_size_el values of layer
@@ -17593,14 +20511,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
 
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     // Write the layer count
     data_ctx.write(&n_layer, sizeof(n_layer));
 
-    // Write n_embd_v_gqa
-    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+    // Write n_embd_v_gqa (reference value)
+    {
+        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
+        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+    }
 
     // Iterate the ranges and write all the pos (this is the token position in the prompt)
     for (const auto & range : cell_ranges) {
@@ -17614,6 +20533,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // Get whole range at a time
     std::vector tmp_buf;
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Write key type
         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
         data_ctx.write(&k_type_i, sizeof(k_type_i));
@@ -17634,6 +20555,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -17654,6 +20577,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         // For the values, they are transposed, so we also need the element size and get the element ranges from each row
         const uint32_t kv_size = kv_self.size;
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -17722,14 +20647,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // Sanity check model compatibility
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+
     if (n_layer != n_layer_ref) {
         LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
         return 0;
     }
-    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
+
+    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
+        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
         return 0;
     }
 
@@ -17769,6 +20694,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Read type of key
         int32_t k_type_i_ref;
         memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
@@ -17801,6 +20728,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -17832,6 +20761,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     } else {
         // For each layer, read the values for each cell (transposed)
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -17969,6 +20900,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
@@ -18028,6 +20963,17 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits)   free(batch.logits);
 }
 
+int32_t llama_encode(
+        struct llama_context * ctx,
+          struct llama_batch   batch) {
+    const int ret = llama_encode_internal(*ctx, batch);
+    if (ret < 0) {
+        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
@@ -18175,9 +21121,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token)
     return model->vocab.id_to_token[token].score;
 }
 
-llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
+llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
     LM_GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].type;
+    return model->vocab.id_to_token[token].attr;
 }
 
 bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
@@ -18212,11 +21158,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
 }
 
 int32_t llama_add_bos_token(const struct llama_model * model) {
-    return model->vocab.special_add_bos;
+    return model->vocab.tokenizer_add_bos;
 }
 
 int32_t llama_add_eos_token(const struct llama_model * model) {
-    return model->vocab.special_add_eos;
+    return model->vocab.tokenizer_add_eos;
 }
 
 llama_token llama_token_prefix(const struct llama_model * model) {
@@ -18235,6 +21181,10 @@ llama_token llama_token_eot(const struct llama_model * model) {
     return model->vocab.special_eot_id;
 }
 
+llama_token llama_token_pad(const struct llama_model * model) {
+    return model->vocab.special_pad_id;
+}
+
 int32_t llama_tokenize(
     const struct llama_model * model,
                   const char * text,
@@ -18244,7 +21194,6 @@ int32_t llama_tokenize(
                         bool   add_special,
                         bool   parse_special) {
     auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
-
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
         return -((int) res.size());
@@ -18265,7 +21214,7 @@ static std::string llama_decode_text(const std::string & text) {
         const auto utf8 = unicode_cpt_to_utf8(cpt);
         try {
             decoded_text += unicode_utf8_to_byte(utf8);
-        } catch (const std::out_of_range & e) {
+        } catch (const std::out_of_range & /*e*/) {
             decoded_text += "[UNK_BYTE_0x";
             for (const auto c : utf8) {
                 decoded_text += format("%02x", (uint8_t) c);
@@ -18278,73 +21227,181 @@ static std::string llama_decode_text(const std::string & text) {
 }
 
 // does not write null-terminator to buf
-int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+    const llama_token_attr attr = llama_token_get_attr(model, token);
+    if (!special && (attr & attr_special)) {
+        return 0;
+    }
+
+    // copy piece chars to output text buffer
+    // skip up to 'lstrip' leading spaces before copying
+    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+            token++;
+            size--;
+        }
+        if (length < (int32_t)size) {
+            return -(int32_t) size;
+        }
+        memcpy(buf, token, size);
+        return (int32_t) size;
+    };
+
+    // if we have a cache - use it
+    {
+        const auto & cache = model->vocab.cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & result = cache.at(token);
+            return _try_copy(result.data(), result.size());
+        }
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
+        const std::string & token_text = model->vocab.id_to_token[token].text;
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                    (llama_is_user_defined_token(model->vocab, token)) ||
-                    (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM:
+            case LLAMA_VOCAB_TYPE_UGM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = token_text;
+                    llama_unescape_whitespace(result);
+                    return _try_copy(result.data(), result.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) llama_token_to_byte(model->vocab, token);
+                    return _try_copy((char*) &byte, 1);
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = llama_decode_text(token_text);
+                    return _try_copy(result.data(), result.size());
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
+            default:
+                LM_GGML_ASSERT(false);
         }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+    }
+    return 0;
+}
+
+int32_t llama_detokenize(
+        const struct llama_model * model,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) {
+    int32_t avail = text_len_max;
+    int32_t total = 0;
+
+    // remove the leading space
+    bool remove_space = model->vocab.tokenizer_add_space_prefix;
+
+    if (remove_special && model->vocab.tokenizer_add_bos) {
+        if (n_tokens > 0 && tokens[0] == model->vocab.special_bos_id) {
+            remove_space = false;
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && model->vocab.tokenizer_add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens-1] == model->vocab.special_eos_id) {
+            n_tokens--;
+        }
+    }
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        LM_GGML_ASSERT(avail >= 0);
+        int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, unparse_special);
+        remove_space = false;
+        if (n_chars < 0) {
+            avail = 0;
+            total -= n_chars;
+        } else if (n_chars > 0) {
+            avail -= n_chars;
+            text  += n_chars;
+            total += n_chars;
+        }
+    }
+
+    if (total > text_len_max) {
+        return -total;
+    }
+
+    if (model->vocab.tokenizer_clean_spaces) {
+        text -= total;  // restart text
+
+        // first pass: characters ?!.,  //TODO: where do these characters come from?
+        const int32_t total1 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total1; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                    total--;  // remove space
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                    (llama_is_user_defined_token(model->vocab, token)) ||
-                    (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            }
+            text[total++] = x;
+        }
+
+        // second pass: strip single apostrophe between spaces
+        const int32_t total2 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total2; ++i) {
+            const char x = text[i];
+            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                total--;           // remove prev space
+                text[++i] = '\0';  // remove next space
+            }
+            text[total++] = x;
+        }
+
+        // third pass: apostrophe contractions  //NOTE: this makes sense?
+        const int32_t total3 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total3; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '\'' && i + 1 < total3) {
+                    const char x1 = text[i + 1];
+                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                        //total--;  // remove space
+                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                        total--;  // remove space
+                    } else if (i + 2 < total3) {
+                        const char x2 = text[i + 2];
+                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                            //total--;  // remove space
+                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                            total--;  // remove space
+                        } else {
+                            //total--;  // remove space
+                        }
+                    } else {
+                        //total--;  // remove space
+                    }
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
             }
-            break;
-        }
-        default:
-            LM_GGML_ASSERT(false);
+            text[total++] = x;
         }
     }
-    return 0;
+
+    return total <= text_len_max ? total : -total;
 }
 
 // trim whitespace from the beginning and end of a string
@@ -18368,7 +21425,10 @@ static int32_t llama_chat_apply_template_internal(
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
-    if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
+    auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
+        return tmpl.find(haystack) != std::string::npos;
+    };
+    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
         // chatml template
         for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@@ -18376,16 +21436,16 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
         // llama2 template and its variants
         // [variant] support system message
-        bool support_system_message = tmpl.find("<>") != std::string::npos;
+        bool support_system_message = tmpl_contains("<>") || tmpl == "mistral";
         // [variant] space before + after response
-        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        bool space_around_response = tmpl_contains("' ' + eos_token");
         // [variant] add BOS inside history
-        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
         // [variant] trim spaces from the input message
-        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        bool strip_message = tmpl_contains("content.strip()");
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
         ss << "[INST] ";
@@ -18411,7 +21471,7 @@ static int32_t llama_chat_apply_template_internal(
             }
         }
         // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
+    } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
         // Phi 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -18420,7 +21480,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
+    } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
         // zephyr template
         for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@@ -18428,7 +21488,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
+    } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
         // mlabonne/AlphaMonarch-7B template (the  is included inside history)
         for (auto message : chat) {
             std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message
@@ -18437,7 +21497,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "assistant\n";
         }
-    } else if (tmpl == "gemma" || tmpl.find("") != std::string::npos) {
+    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) {
         // google/gemma-7b-it
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -18459,7 +21519,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "model\n";
         }
-    } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
+    } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
         // OrionStarAI/Orion-14B-Chat
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -18479,7 +21539,7 @@ static int32_t llama_chat_apply_template_internal(
                 ss << message->content << "";
             }
         }
-    } else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
+    } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
         // openchat/openchat-3.5-0106,
         for (auto message : chat) {
             std::string role(message->role);
@@ -18493,13 +21553,13 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "GPT4 Correct Assistant:";
         }
-    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
+    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
         // eachadea/vicuna-13b-1.1 (and Orca variant)
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
                 // Orca-Vicuna variant uses a system prefix
-                if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
+                if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
                     ss << "SYSTEM: " << message->content << "\n";
                 } else {
                     ss << message->content << "\n\n";
@@ -18513,7 +21573,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "ASSISTANT:";
         }
-    } else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
+    } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
         // deepseek-ai/deepseek-coder-33b-instruct
         for (auto message : chat) {
             std::string role(message->role);
@@ -18528,7 +21588,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "### Response:\n";
         }
-    } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
+    } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
         // CohereForAI/c4ai-command-r-plus
         for (auto message : chat) {
             std::string role(message->role);
@@ -18543,7 +21603,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
         }
-    } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
+    } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
         // Llama 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -18552,6 +21612,52 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
+    } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
+        // chatglm3-6b
+        ss << "[gMASK]" << "sop";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n " << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
+    } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) {
+        ss << "[gMASK]" << "";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
+    } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << LU8("<用户>");
+                ss << trim(message->content);
+                ss << "";
+            } else {
+                ss << trim(message->content);
+            }
+        }
+    } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+        // DeepSeek-V2
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
@@ -18737,6 +21843,8 @@ void llama_log_set(lm_ggml_log_callback log_callback, void * user_data) {
     lm_ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
 #elif defined(LM_GGML_USE_CUDA)
     lm_ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#elif defined(LM_GGML_USE_CANN)
+    lm_ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
 #endif
 }
 
diff --git a/cpp/llama.h b/cpp/llama.h
index 565cce03..54cd2ed5 100644
--- a/cpp/llama.h
+++ b/cpp/llama.h
@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 6
+#define LLAMA_SESSION_VERSION 7
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1
@@ -67,6 +67,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
+        LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
     };
 
     // pre-tokenization types
@@ -86,6 +87,11 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
         LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
+        LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+        LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+        LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
     };
 
     // note: these values should be synchronized with lm_ggml_rope
@@ -97,7 +103,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_GLM  =  4,
     };
 
-    enum llama_token_type {
+    enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
         LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
         LLAMA_TOKEN_TYPE_NORMAL       = 1,
         LLAMA_TOKEN_TYPE_UNKNOWN      = 2,
@@ -107,13 +113,27 @@ extern "C" {
         LLAMA_TOKEN_TYPE_BYTE         = 6,
     };
 
+    enum llama_token_attr {
+        LLAMA_TOKEN_ATTR_UNDEFINED    = 0,
+        LLAMA_TOKEN_ATTR_UNKNOWN      = 1 << 0,
+        LLAMA_TOKEN_ATTR_UNUSED       = 1 << 1,
+        LLAMA_TOKEN_ATTR_NORMAL       = 1 << 2,
+        LLAMA_TOKEN_ATTR_CONTROL      = 1 << 3,  // SPECIAL?
+        LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4,
+        LLAMA_TOKEN_ATTR_BYTE         = 1 << 5,
+        LLAMA_TOKEN_ATTR_NORMALIZED   = 1 << 6,
+        LLAMA_TOKEN_ATTR_LSTRIP       = 1 << 7,
+        LLAMA_TOKEN_ATTR_RSTRIP       = 1 << 8,
+        LLAMA_TOKEN_ATTR_SINGLE_WORD  = 1 << 9,
+    };
+
     // model file types
     enum llama_ftype {
         LLAMA_FTYPE_ALL_F32              = 0,
         LLAMA_FTYPE_MOSTLY_F16           = 1,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0          = 2,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_1          = 3,  // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
+        // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
         // LLAMA_FTYPE_MOSTLY_Q4_2       = 5,  // support has been removed
         // LLAMA_FTYPE_MOSTLY_Q4_3       = 6,  // support has been removed
         LLAMA_FTYPE_MOSTLY_Q8_0          = 7,  // except 1d tensors
@@ -142,6 +162,9 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_XS        = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M         = 31, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_BF16          = 32, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_4_4      = 33, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -159,6 +182,13 @@ extern "C" {
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
+        LLAMA_POOLING_TYPE_LAST = 3,
+    };
+
+    enum llama_attention_type {
+        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
+        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
+        LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
     };
 
     enum llama_split_mode {
@@ -278,7 +308,7 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-                                                        // (ignored if no pooling layer)
+        enum llama_attention_type    attention_type;    // attention type to use for embeddings
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -351,6 +381,9 @@ extern "C" {
         // modifies a preceding LLAMA_GRETYPE_CHAR or
         // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
         LLAMA_GRETYPE_CHAR_ALT       = 6,
+
+        // any character (.)
+        LLAMA_GRETYPE_CHAR_ANY       = 7,
     };
 
     typedef struct llama_grammar_element {
@@ -378,6 +411,9 @@ extern "C" {
         const char * content;
     } llama_chat_message;
 
+    // lora adapter
+    struct llama_lora_adapter;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -424,8 +460,8 @@ extern "C" {
 
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model   * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model   * model);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -464,24 +500,41 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct lm_ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
             const char * fname_out,
             const llama_model_quantize_params * params);
 
-    // Apply a LoRA adapter to a loaded model
-    // path_base_model is the path to a higher quality model to use as a base for
-    // the layers modified by the adapter. Can be NULL to use the current loaded model.
-    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
-    // will be applied on top of the previous one
-    // Returns 0 on success
-    LLAMA_API int32_t llama_model_apply_lora_from_file(
-            const struct llama_model * model,
-                          const char * path_lora,
-                               float   scale,
-                          const char * path_base_model,
-                             int32_t   n_threads);
+    // Load a LoRA adapter from file
+    // The loaded adapter will be associated to the given model, and will be free when the model is deleted
+    LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+            struct llama_model * model,
+            const char * path_lora);
+
+    // Add a loaded LoRA adapter to given context
+    // This will not modify model's weight
+    LLAMA_API int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale);
+
+    // Remove a LoRA adapter from given context
+    // Return -1 if the adapter is not present in the context
+    LLAMA_API int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter);
+
+    // Manually free a LoRA adapter
+    // Note: loaded adapters will be free when the associated model is deleted
+    LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
@@ -749,6 +802,14 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
+    // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
+    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    //   0 - success
+    // < 0 - error
+    LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+              struct llama_batch   batch);
+
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@@ -768,6 +829,10 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
 
+    // Set whether the model is in embeddings mode or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@@ -821,7 +886,7 @@ extern "C" {
 
     LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
 
-    LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
+    LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
 
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
@@ -835,6 +900,7 @@ extern "C" {
     LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
     LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
+    LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
 
     // Returns -1 if unknown, 1 for true or 0 for false.
     LLAMA_API int32_t         llama_add_bos_token(const struct llama_model * model);
@@ -856,6 +922,7 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(
@@ -870,15 +937,31 @@ extern "C" {
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.
     // Does not write null terminator to the buffer.
-    // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
+    // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
     // @param special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_token_to_piece(
               const struct llama_model * model,
                            llama_token   token,
                                   char * buf,
                                int32_t   length,
+                               int32_t   lstrip,
                                   bool   special);
 
+    /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
+    /// @param text The char pointer must be large enough to hold the resulting text.
+    /// @return Returns the number of chars/bytes on success, no more than text_len_max.
+    /// @return Returns a negative number on failure - the number of chars/bytes that would have been returned.
+    /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
+    /// @param unparse_special If true, special tokens are rendered in the output.
+    LLAMA_API int32_t llama_detokenize(
+        const struct llama_model * model,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special);
+
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
     /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -902,6 +985,12 @@ extern "C" {
     // Grammar
     //
 
+    /// Initialize a llama_grammar.
+    ///
+    /// @param rules The rule elements of the grammar to initialize.
+    /// @param n_rules The number of rules.
+    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
+    /// @return The initialized llama_grammar or nullptr if initialization failed.
     LLAMA_API struct llama_grammar * llama_grammar_init(
             const llama_grammar_element ** rules,
                                  size_t    n_rules,
@@ -1042,49 +1131,9 @@ extern "C" {
                      llama_token   token);
 
     //
-    // Beam search
+    // Model split
     //
 
-    struct llama_beam_view {
-        const llama_token * tokens;
-
-        size_t n_tokens;
-        float  p;        // Cumulative beam probability (renormalized relative to all beams)
-        bool   eob;      // Callback should set this to true when a beam is at end-of-beam.
-    };
-
-    // Passed to beam_search_callback function.
-    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
-    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
-    // These pointers are valid only during the synchronous callback, so should not be saved.
-    struct llama_beams_state {
-        struct llama_beam_view * beam_views;
-
-        size_t n_beams;               // Number of elements in beam_views[].
-        size_t common_prefix_length;  // Current max length of prefix tokens shared by all beams.
-        bool   last_call;             // True iff this is the last callback invocation.
-    };
-
-    // Type of pointer to the beam_search_callback function.
-    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
-    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
-    typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
-
-    /// @details Deterministically returns entire sentence constructed by a beam search.
-    /// @param ctx Pointer to the llama_context.
-    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
-    /// @param callback_data A pointer that is simply passed back to callback.
-    /// @param n_beams Number of beams to use.
-    /// @param n_past Number of tokens already evaluated.
-    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
-    LLAMA_API void llama_beam_search(
-                   struct llama_context * ctx,
-        llama_beam_search_callback_fn_t   callback,
-                                   void * callback_data,
-                                 size_t   n_beams,
-                                int32_t   n_past,
-                                int32_t   n_predict);
-
     /// @details Build a split GGUF final path for this chunk.
     ///          llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
     //  Returns the split_path length.
diff --git a/cpp/log.h b/cpp/log.h
index 2cd0b543..daad7e43 100644
--- a/cpp/log.h
+++ b/cpp/log.h
@@ -643,7 +643,7 @@ inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
     buf << "[ ";
 
     bool first = true;
-    for (const auto &token : tokens)
+    for (const auto & token : tokens)
     {
         if (!first) {
             buf << ", ";
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index c9417094..a83676f7 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -253,9 +253,17 @@ struct llama_rn_context
 
     void loadPrompt()
     {
-        std::vector prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        std::vector prompt_tokens = ::llama_tokenize(ctx, params.prompt, true, true);
         num_prompt_tokens = prompt_tokens.size();
 
+        // LOG tokens
+        std::stringstream ss;
+        ss << "\n" << __func__ << ": prompt_tokens = ";
+        for (auto& token : prompt_tokens) {
+            ss << token << " ";
+        }
+        LOG_INFO("%s\n", ss.str().c_str());
+
         if (params.n_keep < 0)
         {
             params.n_keep = (int)num_prompt_tokens;
diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp
index d591e52f..794aa433 100644
--- a/cpp/sampling.cpp
+++ b/cpp/sampling.cpp
@@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
         std::vector grammar_rules(result->parsed_grammar.c_rules());
 
-        result->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        result->grammar = grammar;
     }
 
     result->prev.resize(params.n_prev);
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     if (!ctx->parsed_grammar.rules.empty()) {
         std::vector grammar_rules(ctx->parsed_grammar.c_rules());
 
-        ctx->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        ctx->grammar = grammar;
     }
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
@@ -274,8 +282,6 @@ static llama_token llama_sampling_sample_impl(
         LM_GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -316,6 +322,9 @@ static llama_token llama_sampling_sample_impl(
     }
 
     if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Get a pointer to the logits
+        float * logits = llama_get_logits_ith(ctx_main, idx);
+
         // Create an array with a single token data element for the sampled id
         llama_token_data single_token_data = {id, logits[id], 0.0f};
         llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
@@ -369,7 +378,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         LM_GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+        *original_logits = {logits, logits + n_vocab};
     }
 
     // apply params.logit_bias map
@@ -382,10 +391,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
 
-    cur.clear();
+    cur.resize(n_vocab);
 
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
diff --git a/cpp/sgemm.cpp b/cpp/sgemm.cpp
index 16945992..0205fd96 100644
--- a/cpp/sgemm.cpp
+++ b/cpp/sgemm.cpp
@@ -43,8 +43,10 @@
 // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
 //     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
 
+#if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
 
 #include "sgemm.h"
 #include "ggml-impl.h"
@@ -247,9 +249,8 @@ class tinyBLAS {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == LM_GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -456,9 +457,8 @@ class tinyBLAS_Q0_ARM {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == LM_GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -594,9 +594,8 @@ class tinyBLAS_Q0_AVX {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == LM_GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -827,7 +826,7 @@ class tinyBLAS_Q0_AVX {
  * For example, for single-threaded single-precision GEMM you can say
  *
  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- *                     0, 1, LM_GGML_TASK_TYPE_COMPUTE,
+ *                     0, 1,
  *                     LM_GGML_TYPE_F32, LM_GGML_TYPE_F32, LM_GGML_TYPE_F32);
  *
  * @param m is rows in `A` and `C`
@@ -841,14 +840,13 @@ class tinyBLAS_Q0_AVX {
  * @param ldc is row stride of `C`
  * @param ith is thread id (must be less than `nth`)
  * @param nth is number of threads (must be greater than zero)
- * @param task is GGML task type
  * @param Atype is GGML data type of `A`
  * @param Btype is GGML data type of `B`
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
 bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -875,7 +873,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__AVX__) || defined(__AVX2__)
         if (k % 8)
@@ -885,7 +883,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_NEON)
         if (n < 4)
@@ -897,7 +895,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -915,7 +913,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (k % 8)
@@ -927,7 +925,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
@@ -941,7 +939,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const lm_ggml_fp16_t *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (k % 4)
@@ -953,7 +951,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -969,7 +967,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM tb{
@@ -977,7 +975,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -993,7 +991,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM tb{
@@ -1001,7 +999,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -1023,7 +1021,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldc;
     (void)ith;
     (void)nth;
-    (void)task;
     (void)Atype;
     (void)Btype;
     (void)Ctype;
diff --git a/cpp/sgemm.h b/cpp/sgemm.h
index f29747d0..caf6dd55 100644
--- a/cpp/sgemm.h
+++ b/cpp/sgemm.h
@@ -7,7 +7,7 @@ extern "C" {
 
 bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
                      const void *, int64_t, void *, int64_t, int, int,
-                     int, int, int, int);
+                     int, int, int);
 
 #ifdef __cplusplus
 }
diff --git a/cpp/unicode-data.cpp b/cpp/unicode-data.cpp
index d7c1c898..02bdf782 100644
--- a/cpp/unicode-data.cpp
+++ b/cpp/unicode-data.cpp
@@ -68,36 +68,36 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000370, 0x0004},
 {0x000375, 0x0040},
 {0x000376, 0x0004},
-{0x000378, 0x0080},
+{0x000378, 0x0001},
 {0x00037A, 0x0004},
 {0x00037E, 0x0020},
 {0x00037F, 0x0004},
-{0x000380, 0x0080},
+{0x000380, 0x0001},
 {0x000384, 0x0040},
 {0x000386, 0x0004},
 {0x000387, 0x0020},
 {0x000388, 0x0004},
-{0x00038B, 0x0080},
+{0x00038B, 0x0001},
 {0x00038C, 0x0004},
-{0x00038D, 0x0080},
+{0x00038D, 0x0001},
 {0x00038E, 0x0004},
-{0x0003A2, 0x0080},
+{0x0003A2, 0x0001},
 {0x0003A3, 0x0004},
 {0x0003F6, 0x0040},
 {0x0003F7, 0x0004},
 {0x000482, 0x0040},
 {0x000483, 0x0010},
 {0x00048A, 0x0004},
-{0x000530, 0x0080},
+{0x000530, 0x0001},
 {0x000531, 0x0004},
-{0x000557, 0x0080},
+{0x000557, 0x0001},
 {0x000559, 0x0004},
 {0x00055A, 0x0020},
 {0x000560, 0x0004},
 {0x000589, 0x0020},
-{0x00058B, 0x0080},
+{0x00058B, 0x0001},
 {0x00058D, 0x0040},
-{0x000590, 0x0080},
+{0x000590, 0x0001},
 {0x000591, 0x0010},
 {0x0005BE, 0x0020},
 {0x0005BF, 0x0010},
@@ -107,12 +107,13 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0005C4, 0x0010},
 {0x0005C6, 0x0020},
 {0x0005C7, 0x0010},
-{0x0005C8, 0x0080},
+{0x0005C8, 0x0001},
 {0x0005D0, 0x0004},
-{0x0005EB, 0x0080},
+{0x0005EB, 0x0001},
 {0x0005EF, 0x0004},
 {0x0005F3, 0x0020},
-{0x0005F5, 0x0080},
+{0x0005F5, 0x0001},
+{0x000600, 0x0080},
 {0x000606, 0x0040},
 {0x000609, 0x0020},
 {0x00060B, 0x0040},
@@ -145,16 +146,17 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0006FD, 0x0040},
 {0x0006FF, 0x0004},
 {0x000700, 0x0020},
-{0x00070E, 0x0080},
+{0x00070E, 0x0001},
+{0x00070F, 0x0080},
 {0x000710, 0x0004},
 {0x000711, 0x0010},
 {0x000712, 0x0004},
 {0x000730, 0x0010},
-{0x00074B, 0x0080},
+{0x00074B, 0x0001},
 {0x00074D, 0x0004},
 {0x0007A6, 0x0010},
 {0x0007B1, 0x0004},
-{0x0007B2, 0x0080},
+{0x0007B2, 0x0001},
 {0x0007C0, 0x0002},
 {0x0007CA, 0x0004},
 {0x0007EB, 0x0010},
@@ -162,7 +164,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0007F6, 0x0040},
 {0x0007F7, 0x0020},
 {0x0007FA, 0x0004},
-{0x0007FB, 0x0080},
+{0x0007FB, 0x0001},
 {0x0007FD, 0x0010},
 {0x0007FE, 0x0040},
 {0x000800, 0x0004},
@@ -173,20 +175,22 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000825, 0x0010},
 {0x000828, 0x0004},
 {0x000829, 0x0010},
-{0x00082E, 0x0080},
+{0x00082E, 0x0001},
 {0x000830, 0x0020},
-{0x00083F, 0x0080},
+{0x00083F, 0x0001},
 {0x000840, 0x0004},
 {0x000859, 0x0010},
-{0x00085C, 0x0080},
+{0x00085C, 0x0001},
 {0x00085E, 0x0020},
-{0x00085F, 0x0080},
+{0x00085F, 0x0001},
 {0x000860, 0x0004},
-{0x00086B, 0x0080},
+{0x00086B, 0x0001},
 {0x000870, 0x0004},
 {0x000888, 0x0040},
 {0x000889, 0x0004},
-{0x00088F, 0x0080},
+{0x00088F, 0x0001},
+{0x000890, 0x0080},
+{0x000892, 0x0001},
 {0x000898, 0x0010},
 {0x0008A0, 0x0004},
 {0x0008CA, 0x0010},
@@ -205,35 +209,35 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000970, 0x0020},
 {0x000971, 0x0004},
 {0x000981, 0x0010},
-{0x000984, 0x0080},
+{0x000984, 0x0001},
 {0x000985, 0x0004},
-{0x00098D, 0x0080},
+{0x00098D, 0x0001},
 {0x00098F, 0x0004},
-{0x000991, 0x0080},
+{0x000991, 0x0001},
 {0x000993, 0x0004},
-{0x0009A9, 0x0080},
+{0x0009A9, 0x0001},
 {0x0009AA, 0x0004},
-{0x0009B1, 0x0080},
+{0x0009B1, 0x0001},
 {0x0009B2, 0x0004},
-{0x0009B3, 0x0080},
+{0x0009B3, 0x0001},
 {0x0009B6, 0x0004},
-{0x0009BA, 0x0080},
+{0x0009BA, 0x0001},
 {0x0009BC, 0x0010},
 {0x0009BD, 0x0004},
 {0x0009BE, 0x0010},
-{0x0009C5, 0x0080},
+{0x0009C5, 0x0001},
 {0x0009C7, 0x0010},
-{0x0009C9, 0x0080},
+{0x0009C9, 0x0001},
 {0x0009CB, 0x0010},
 {0x0009CE, 0x0004},
-{0x0009CF, 0x0080},
+{0x0009CF, 0x0001},
 {0x0009D7, 0x0010},
-{0x0009D8, 0x0080},
+{0x0009D8, 0x0001},
 {0x0009DC, 0x0004},
-{0x0009DE, 0x0080},
+{0x0009DE, 0x0001},
 {0x0009DF, 0x0004},
 {0x0009E2, 0x0010},
-{0x0009E4, 0x0080},
+{0x0009E4, 0x0001},
 {0x0009E6, 0x0002},
 {0x0009F0, 0x0004},
 {0x0009F2, 0x0040},
@@ -242,173 +246,173 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0009FC, 0x0004},
 {0x0009FD, 0x0020},
 {0x0009FE, 0x0010},
-{0x0009FF, 0x0080},
+{0x0009FF, 0x0001},
 {0x000A01, 0x0010},
-{0x000A04, 0x0080},
+{0x000A04, 0x0001},
 {0x000A05, 0x0004},
-{0x000A0B, 0x0080},
+{0x000A0B, 0x0001},
 {0x000A0F, 0x0004},
-{0x000A11, 0x0080},
+{0x000A11, 0x0001},
 {0x000A13, 0x0004},
-{0x000A29, 0x0080},
+{0x000A29, 0x0001},
 {0x000A2A, 0x0004},
-{0x000A31, 0x0080},
+{0x000A31, 0x0001},
 {0x000A32, 0x0004},
-{0x000A34, 0x0080},
+{0x000A34, 0x0001},
 {0x000A35, 0x0004},
-{0x000A37, 0x0080},
+{0x000A37, 0x0001},
 {0x000A38, 0x0004},
-{0x000A3A, 0x0080},
+{0x000A3A, 0x0001},
 {0x000A3C, 0x0010},
-{0x000A3D, 0x0080},
+{0x000A3D, 0x0001},
 {0x000A3E, 0x0010},
-{0x000A43, 0x0080},
+{0x000A43, 0x0001},
 {0x000A47, 0x0010},
-{0x000A49, 0x0080},
+{0x000A49, 0x0001},
 {0x000A4B, 0x0010},
-{0x000A4E, 0x0080},
+{0x000A4E, 0x0001},
 {0x000A51, 0x0010},
-{0x000A52, 0x0080},
+{0x000A52, 0x0001},
 {0x000A59, 0x0004},
-{0x000A5D, 0x0080},
+{0x000A5D, 0x0001},
 {0x000A5E, 0x0004},
-{0x000A5F, 0x0080},
+{0x000A5F, 0x0001},
 {0x000A66, 0x0002},
 {0x000A70, 0x0010},
 {0x000A72, 0x0004},
 {0x000A75, 0x0010},
 {0x000A76, 0x0020},
-{0x000A77, 0x0080},
+{0x000A77, 0x0001},
 {0x000A81, 0x0010},
-{0x000A84, 0x0080},
+{0x000A84, 0x0001},
 {0x000A85, 0x0004},
-{0x000A8E, 0x0080},
+{0x000A8E, 0x0001},
 {0x000A8F, 0x0004},
-{0x000A92, 0x0080},
+{0x000A92, 0x0001},
 {0x000A93, 0x0004},
-{0x000AA9, 0x0080},
+{0x000AA9, 0x0001},
 {0x000AAA, 0x0004},
-{0x000AB1, 0x0080},
+{0x000AB1, 0x0001},
 {0x000AB2, 0x0004},
-{0x000AB4, 0x0080},
+{0x000AB4, 0x0001},
 {0x000AB5, 0x0004},
-{0x000ABA, 0x0080},
+{0x000ABA, 0x0001},
 {0x000ABC, 0x0010},
 {0x000ABD, 0x0004},
 {0x000ABE, 0x0010},
-{0x000AC6, 0x0080},
+{0x000AC6, 0x0001},
 {0x000AC7, 0x0010},
-{0x000ACA, 0x0080},
+{0x000ACA, 0x0001},
 {0x000ACB, 0x0010},
-{0x000ACE, 0x0080},
+{0x000ACE, 0x0001},
 {0x000AD0, 0x0004},
-{0x000AD1, 0x0080},
+{0x000AD1, 0x0001},
 {0x000AE0, 0x0004},
 {0x000AE2, 0x0010},
-{0x000AE4, 0x0080},
+{0x000AE4, 0x0001},
 {0x000AE6, 0x0002},
 {0x000AF0, 0x0020},
 {0x000AF1, 0x0040},
-{0x000AF2, 0x0080},
+{0x000AF2, 0x0001},
 {0x000AF9, 0x0004},
 {0x000AFA, 0x0010},
-{0x000B00, 0x0080},
+{0x000B00, 0x0001},
 {0x000B01, 0x0010},
-{0x000B04, 0x0080},
+{0x000B04, 0x0001},
 {0x000B05, 0x0004},
-{0x000B0D, 0x0080},
+{0x000B0D, 0x0001},
 {0x000B0F, 0x0004},
-{0x000B11, 0x0080},
+{0x000B11, 0x0001},
 {0x000B13, 0x0004},
-{0x000B29, 0x0080},
+{0x000B29, 0x0001},
 {0x000B2A, 0x0004},
-{0x000B31, 0x0080},
+{0x000B31, 0x0001},
 {0x000B32, 0x0004},
-{0x000B34, 0x0080},
+{0x000B34, 0x0001},
 {0x000B35, 0x0004},
-{0x000B3A, 0x0080},
+{0x000B3A, 0x0001},
 {0x000B3C, 0x0010},
 {0x000B3D, 0x0004},
 {0x000B3E, 0x0010},
-{0x000B45, 0x0080},
+{0x000B45, 0x0001},
 {0x000B47, 0x0010},
-{0x000B49, 0x0080},
+{0x000B49, 0x0001},
 {0x000B4B, 0x0010},
-{0x000B4E, 0x0080},
+{0x000B4E, 0x0001},
 {0x000B55, 0x0010},
-{0x000B58, 0x0080},
+{0x000B58, 0x0001},
 {0x000B5C, 0x0004},
-{0x000B5E, 0x0080},
+{0x000B5E, 0x0001},
 {0x000B5F, 0x0004},
 {0x000B62, 0x0010},
-{0x000B64, 0x0080},
+{0x000B64, 0x0001},
 {0x000B66, 0x0002},
 {0x000B70, 0x0040},
 {0x000B71, 0x0004},
 {0x000B72, 0x0002},
-{0x000B78, 0x0080},
+{0x000B78, 0x0001},
 {0x000B82, 0x0010},
 {0x000B83, 0x0004},
-{0x000B84, 0x0080},
+{0x000B84, 0x0001},
 {0x000B85, 0x0004},
-{0x000B8B, 0x0080},
+{0x000B8B, 0x0001},
 {0x000B8E, 0x0004},
-{0x000B91, 0x0080},
+{0x000B91, 0x0001},
 {0x000B92, 0x0004},
-{0x000B96, 0x0080},
+{0x000B96, 0x0001},
 {0x000B99, 0x0004},
-{0x000B9B, 0x0080},
+{0x000B9B, 0x0001},
 {0x000B9C, 0x0004},
-{0x000B9D, 0x0080},
+{0x000B9D, 0x0001},
 {0x000B9E, 0x0004},
-{0x000BA0, 0x0080},
+{0x000BA0, 0x0001},
 {0x000BA3, 0x0004},
-{0x000BA5, 0x0080},
+{0x000BA5, 0x0001},
 {0x000BA8, 0x0004},
-{0x000BAB, 0x0080},
+{0x000BAB, 0x0001},
 {0x000BAE, 0x0004},
-{0x000BBA, 0x0080},
+{0x000BBA, 0x0001},
 {0x000BBE, 0x0010},
-{0x000BC3, 0x0080},
+{0x000BC3, 0x0001},
 {0x000BC6, 0x0010},
-{0x000BC9, 0x0080},
+{0x000BC9, 0x0001},
 {0x000BCA, 0x0010},
-{0x000BCE, 0x0080},
+{0x000BCE, 0x0001},
 {0x000BD0, 0x0004},
-{0x000BD1, 0x0080},
+{0x000BD1, 0x0001},
 {0x000BD7, 0x0010},
-{0x000BD8, 0x0080},
+{0x000BD8, 0x0001},
 {0x000BE6, 0x0002},
 {0x000BF3, 0x0040},
-{0x000BFB, 0x0080},
+{0x000BFB, 0x0001},
 {0x000C00, 0x0010},
 {0x000C05, 0x0004},
-{0x000C0D, 0x0080},
+{0x000C0D, 0x0001},
 {0x000C0E, 0x0004},
-{0x000C11, 0x0080},
+{0x000C11, 0x0001},
 {0x000C12, 0x0004},
-{0x000C29, 0x0080},
+{0x000C29, 0x0001},
 {0x000C2A, 0x0004},
-{0x000C3A, 0x0080},
+{0x000C3A, 0x0001},
 {0x000C3C, 0x0010},
 {0x000C3D, 0x0004},
 {0x000C3E, 0x0010},
-{0x000C45, 0x0080},
+{0x000C45, 0x0001},
 {0x000C46, 0x0010},
-{0x000C49, 0x0080},
+{0x000C49, 0x0001},
 {0x000C4A, 0x0010},
-{0x000C4E, 0x0080},
+{0x000C4E, 0x0001},
 {0x000C55, 0x0010},
-{0x000C57, 0x0080},
+{0x000C57, 0x0001},
 {0x000C58, 0x0004},
-{0x000C5B, 0x0080},
+{0x000C5B, 0x0001},
 {0x000C5D, 0x0004},
-{0x000C5E, 0x0080},
+{0x000C5E, 0x0001},
 {0x000C60, 0x0004},
 {0x000C62, 0x0010},
-{0x000C64, 0x0080},
+{0x000C64, 0x0001},
 {0x000C66, 0x0002},
-{0x000C70, 0x0080},
+{0x000C70, 0x0001},
 {0x000C77, 0x0020},
 {0x000C78, 0x0002},
 {0x000C7F, 0x0040},
@@ -416,124 +420,124 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000C81, 0x0010},
 {0x000C84, 0x0020},
 {0x000C85, 0x0004},
-{0x000C8D, 0x0080},
+{0x000C8D, 0x0001},
 {0x000C8E, 0x0004},
-{0x000C91, 0x0080},
+{0x000C91, 0x0001},
 {0x000C92, 0x0004},
-{0x000CA9, 0x0080},
+{0x000CA9, 0x0001},
 {0x000CAA, 0x0004},
-{0x000CB4, 0x0080},
+{0x000CB4, 0x0001},
 {0x000CB5, 0x0004},
-{0x000CBA, 0x0080},
+{0x000CBA, 0x0001},
 {0x000CBC, 0x0010},
 {0x000CBD, 0x0004},
 {0x000CBE, 0x0010},
-{0x000CC5, 0x0080},
+{0x000CC5, 0x0001},
 {0x000CC6, 0x0010},
-{0x000CC9, 0x0080},
+{0x000CC9, 0x0001},
 {0x000CCA, 0x0010},
-{0x000CCE, 0x0080},
+{0x000CCE, 0x0001},
 {0x000CD5, 0x0010},
-{0x000CD7, 0x0080},
+{0x000CD7, 0x0001},
 {0x000CDD, 0x0004},
-{0x000CDF, 0x0080},
+{0x000CDF, 0x0001},
 {0x000CE0, 0x0004},
 {0x000CE2, 0x0010},
-{0x000CE4, 0x0080},
+{0x000CE4, 0x0001},
 {0x000CE6, 0x0002},
-{0x000CF0, 0x0080},
+{0x000CF0, 0x0001},
 {0x000CF1, 0x0004},
 {0x000CF3, 0x0010},
-{0x000CF4, 0x0080},
+{0x000CF4, 0x0001},
 {0x000D00, 0x0010},
 {0x000D04, 0x0004},
-{0x000D0D, 0x0080},
+{0x000D0D, 0x0001},
 {0x000D0E, 0x0004},
-{0x000D11, 0x0080},
+{0x000D11, 0x0001},
 {0x000D12, 0x0004},
 {0x000D3B, 0x0010},
 {0x000D3D, 0x0004},
 {0x000D3E, 0x0010},
-{0x000D45, 0x0080},
+{0x000D45, 0x0001},
 {0x000D46, 0x0010},
-{0x000D49, 0x0080},
+{0x000D49, 0x0001},
 {0x000D4A, 0x0010},
 {0x000D4E, 0x0004},
 {0x000D4F, 0x0040},
-{0x000D50, 0x0080},
+{0x000D50, 0x0001},
 {0x000D54, 0x0004},
 {0x000D57, 0x0010},
 {0x000D58, 0x0002},
 {0x000D5F, 0x0004},
 {0x000D62, 0x0010},
-{0x000D64, 0x0080},
+{0x000D64, 0x0001},
 {0x000D66, 0x0002},
 {0x000D79, 0x0040},
 {0x000D7A, 0x0004},
-{0x000D80, 0x0080},
+{0x000D80, 0x0001},
 {0x000D81, 0x0010},
-{0x000D84, 0x0080},
+{0x000D84, 0x0001},
 {0x000D85, 0x0004},
-{0x000D97, 0x0080},
+{0x000D97, 0x0001},
 {0x000D9A, 0x0004},
-{0x000DB2, 0x0080},
+{0x000DB2, 0x0001},
 {0x000DB3, 0x0004},
-{0x000DBC, 0x0080},
+{0x000DBC, 0x0001},
 {0x000DBD, 0x0004},
-{0x000DBE, 0x0080},
+{0x000DBE, 0x0001},
 {0x000DC0, 0x0004},
-{0x000DC7, 0x0080},
+{0x000DC7, 0x0001},
 {0x000DCA, 0x0010},
-{0x000DCB, 0x0080},
+{0x000DCB, 0x0001},
 {0x000DCF, 0x0010},
-{0x000DD5, 0x0080},
+{0x000DD5, 0x0001},
 {0x000DD6, 0x0010},
-{0x000DD7, 0x0080},
+{0x000DD7, 0x0001},
 {0x000DD8, 0x0010},
-{0x000DE0, 0x0080},
+{0x000DE0, 0x0001},
 {0x000DE6, 0x0002},
-{0x000DF0, 0x0080},
+{0x000DF0, 0x0001},
 {0x000DF2, 0x0010},
 {0x000DF4, 0x0020},
-{0x000DF5, 0x0080},
+{0x000DF5, 0x0001},
 {0x000E01, 0x0004},
 {0x000E31, 0x0010},
 {0x000E32, 0x0004},
 {0x000E34, 0x0010},
-{0x000E3B, 0x0080},
+{0x000E3B, 0x0001},
 {0x000E3F, 0x0040},
 {0x000E40, 0x0004},
 {0x000E47, 0x0010},
 {0x000E4F, 0x0020},
 {0x000E50, 0x0002},
 {0x000E5A, 0x0020},
-{0x000E5C, 0x0080},
+{0x000E5C, 0x0001},
 {0x000E81, 0x0004},
-{0x000E83, 0x0080},
+{0x000E83, 0x0001},
 {0x000E84, 0x0004},
-{0x000E85, 0x0080},
+{0x000E85, 0x0001},
 {0x000E86, 0x0004},
-{0x000E8B, 0x0080},
+{0x000E8B, 0x0001},
 {0x000E8C, 0x0004},
-{0x000EA4, 0x0080},
+{0x000EA4, 0x0001},
 {0x000EA5, 0x0004},
-{0x000EA6, 0x0080},
+{0x000EA6, 0x0001},
 {0x000EA7, 0x0004},
 {0x000EB1, 0x0010},
 {0x000EB2, 0x0004},
 {0x000EB4, 0x0010},
 {0x000EBD, 0x0004},
-{0x000EBE, 0x0080},
+{0x000EBE, 0x0001},
 {0x000EC0, 0x0004},
-{0x000EC5, 0x0080},
+{0x000EC5, 0x0001},
 {0x000EC6, 0x0004},
-{0x000EC7, 0x0080},
+{0x000EC7, 0x0001},
 {0x000EC8, 0x0010},
-{0x000ECF, 0x0080},
+{0x000ECF, 0x0001},
 {0x000ED0, 0x0002},
-{0x000EDA, 0x0080},
+{0x000EDA, 0x0001},
 {0x000EDC, 0x0004},
-{0x000EE0, 0x0080},
+{0x000EE0, 0x0001},
 {0x000F00, 0x0004},
 {0x000F01, 0x0040},
 {0x000F04, 0x0020},
@@ -552,26 +556,26 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000F3A, 0x0020},
 {0x000F3E, 0x0010},
 {0x000F40, 0x0004},
-{0x000F48, 0x0080},
+{0x000F48, 0x0001},
 {0x000F49, 0x0004},
-{0x000F6D, 0x0080},
+{0x000F6D, 0x0001},
 {0x000F71, 0x0010},
 {0x000F85, 0x0020},
 {0x000F86, 0x0010},
 {0x000F88, 0x0004},
 {0x000F8D, 0x0010},
-{0x000F98, 0x0080},
+{0x000F98, 0x0001},
 {0x000F99, 0x0010},
-{0x000FBD, 0x0080},
+{0x000FBD, 0x0001},
 {0x000FBE, 0x0040},
 {0x000FC6, 0x0010},
 {0x000FC7, 0x0040},
-{0x000FCD, 0x0080},
+{0x000FCD, 0x0001},
 {0x000FCE, 0x0040},
 {0x000FD0, 0x0020},
 {0x000FD5, 0x0040},
 {0x000FD9, 0x0020},
-{0x000FDB, 0x0080},
+{0x000FDB, 0x0001},
 {0x001000, 0x0004},
 {0x00102B, 0x0010},
 {0x00103F, 0x0004},
@@ -595,56 +599,56 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00109A, 0x0010},
 {0x00109E, 0x0040},
 {0x0010A0, 0x0004},
-{0x0010C6, 0x0080},
+{0x0010C6, 0x0001},
 {0x0010C7, 0x0004},
-{0x0010C8, 0x0080},
+{0x0010C8, 0x0001},
 {0x0010CD, 0x0004},
-{0x0010CE, 0x0080},
+{0x0010CE, 0x0001},
 {0x0010D0, 0x0004},
 {0x0010FB, 0x0020},
 {0x0010FC, 0x0004},
-{0x001249, 0x0080},
+{0x001249, 0x0001},
 {0x00124A, 0x0004},
-{0x00124E, 0x0080},
+{0x00124E, 0x0001},
 {0x001250, 0x0004},
-{0x001257, 0x0080},
+{0x001257, 0x0001},
 {0x001258, 0x0004},
-{0x001259, 0x0080},
+{0x001259, 0x0001},
 {0x00125A, 0x0004},
-{0x00125E, 0x0080},
+{0x00125E, 0x0001},
 {0x001260, 0x0004},
-{0x001289, 0x0080},
+{0x001289, 0x0001},
 {0x00128A, 0x0004},
-{0x00128E, 0x0080},
+{0x00128E, 0x0001},
 {0x001290, 0x0004},
-{0x0012B1, 0x0080},
+{0x0012B1, 0x0001},
 {0x0012B2, 0x0004},
-{0x0012B6, 0x0080},
+{0x0012B6, 0x0001},
 {0x0012B8, 0x0004},
-{0x0012BF, 0x0080},
+{0x0012BF, 0x0001},
 {0x0012C0, 0x0004},
-{0x0012C1, 0x0080},
+{0x0012C1, 0x0001},
 {0x0012C2, 0x0004},
-{0x0012C6, 0x0080},
+{0x0012C6, 0x0001},
 {0x0012C8, 0x0004},
-{0x0012D7, 0x0080},
+{0x0012D7, 0x0001},
 {0x0012D8, 0x0004},
-{0x001311, 0x0080},
+{0x001311, 0x0001},
 {0x001312, 0x0004},
-{0x001316, 0x0080},
+{0x001316, 0x0001},
 {0x001318, 0x0004},
-{0x00135B, 0x0080},
+{0x00135B, 0x0001},
 {0x00135D, 0x0010},
 {0x001360, 0x0020},
 {0x001369, 0x0002},
-{0x00137D, 0x0080},
+{0x00137D, 0x0001},
 {0x001380, 0x0004},
 {0x001390, 0x0040},
-{0x00139A, 0x0080},
+{0x00139A, 0x0001},
 {0x0013A0, 0x0004},
-{0x0013F6, 0x0080},
+{0x0013F6, 0x0001},
 {0x0013F8, 0x0004},
-{0x0013FE, 0x0080},
+{0x0013FE, 0x0001},
 {0x001400, 0x0020},
 {0x001401, 0x0004},
 {0x00166D, 0x0040},
@@ -653,28 +657,28 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x001680, 0x0008},
 {0x001681, 0x0004},
 {0x00169B, 0x0020},
-{0x00169D, 0x0080},
+{0x00169D, 0x0001},
 {0x0016A0, 0x0004},
 {0x0016EB, 0x0020},
 {0x0016EE, 0x0002},
 {0x0016F1, 0x0004},
-{0x0016F9, 0x0080},
+{0x0016F9, 0x0001},
 {0x001700, 0x0004},
 {0x001712, 0x0010},
-{0x001716, 0x0080},
+{0x001716, 0x0001},
 {0x00171F, 0x0004},
 {0x001732, 0x0010},
 {0x001735, 0x0020},
-{0x001737, 0x0080},
+{0x001737, 0x0001},
 {0x001740, 0x0004},
 {0x001752, 0x0010},
-{0x001754, 0x0080},
+{0x001754, 0x0001},
 {0x001760, 0x0004},
-{0x00176D, 0x0080},
+{0x00176D, 0x0001},
 {0x00176E, 0x0004},
-{0x001771, 0x0080},
+{0x001771, 0x0001},
 {0x001772, 0x0010},
-{0x001774, 0x0080},
+{0x001774, 0x0001},
 {0x001780, 0x0004},
 {0x0017B4, 0x0010},
 {0x0017D4, 0x0020},
@@ -683,80 +687,80 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0017DB, 0x0040},
 {0x0017DC, 0x0004},
 {0x0017DD, 0x0010},
-{0x0017DE, 0x0080},
+{0x0017DE, 0x0001},
 {0x0017E0, 0x0002},
-{0x0017EA, 0x0080},
+{0x0017EA, 0x0001},
 {0x0017F0, 0x0002},
-{0x0017FA, 0x0080},
+{0x0017FA, 0x0001},
 {0x001800, 0x0020},
 {0x00180B, 0x0010},
 {0x00180E, 0x0080},
 {0x00180F, 0x0010},
 {0x001810, 0x0002},
-{0x00181A, 0x0080},
+{0x00181A, 0x0001},
 {0x001820, 0x0004},
-{0x001879, 0x0080},
+{0x001879, 0x0001},
 {0x001880, 0x0004},
 {0x001885, 0x0010},
 {0x001887, 0x0004},
 {0x0018A9, 0x0010},
 {0x0018AA, 0x0004},
-{0x0018AB, 0x0080},
+{0x0018AB, 0x0001},
 {0x0018B0, 0x0004},
-{0x0018F6, 0x0080},
+{0x0018F6, 0x0001},
 {0x001900, 0x0004},
-{0x00191F, 0x0080},
+{0x00191F, 0x0001},
 {0x001920, 0x0010},
-{0x00192C, 0x0080},
+{0x00192C, 0x0001},
 {0x001930, 0x0010},
-{0x00193C, 0x0080},
+{0x00193C, 0x0001},
 {0x001940, 0x0040},
-{0x001941, 0x0080},
+{0x001941, 0x0001},
 {0x001944, 0x0020},
 {0x001946, 0x0002},
 {0x001950, 0x0004},
-{0x00196E, 0x0080},
+{0x00196E, 0x0001},
 {0x001970, 0x0004},
-{0x001975, 0x0080},
+{0x001975, 0x0001},
 {0x001980, 0x0004},
-{0x0019AC, 0x0080},
+{0x0019AC, 0x0001},
 {0x0019B0, 0x0004},
-{0x0019CA, 0x0080},
+{0x0019CA, 0x0001},
 {0x0019D0, 0x0002},
-{0x0019DB, 0x0080},
+{0x0019DB, 0x0001},
 {0x0019DE, 0x0040},
 {0x001A00, 0x0004},
 {0x001A17, 0x0010},
-{0x001A1C, 0x0080},
+{0x001A1C, 0x0001},
 {0x001A1E, 0x0020},
 {0x001A20, 0x0004},
 {0x001A55, 0x0010},
-{0x001A5F, 0x0080},
+{0x001A5F, 0x0001},
 {0x001A60, 0x0010},
-{0x001A7D, 0x0080},
+{0x001A7D, 0x0001},
 {0x001A7F, 0x0010},
 {0x001A80, 0x0002},
-{0x001A8A, 0x0080},
+{0x001A8A, 0x0001},
 {0x001A90, 0x0002},
-{0x001A9A, 0x0080},
+{0x001A9A, 0x0001},
 {0x001AA0, 0x0020},
 {0x001AA7, 0x0004},
 {0x001AA8, 0x0020},
-{0x001AAE, 0x0080},
+{0x001AAE, 0x0001},
 {0x001AB0, 0x0010},
-{0x001ACF, 0x0080},
+{0x001ACF, 0x0001},
 {0x001B00, 0x0010},
 {0x001B05, 0x0004},
 {0x001B34, 0x0010},
 {0x001B45, 0x0004},
-{0x001B4D, 0x0080},
+{0x001B4D, 0x0001},
 {0x001B50, 0x0002},
 {0x001B5A, 0x0020},
 {0x001B61, 0x0040},
 {0x001B6B, 0x0010},
 {0x001B74, 0x0040},
 {0x001B7D, 0x0020},
-{0x001B7F, 0x0080},
+{0x001B7F, 0x0001},
 {0x001B80, 0x0010},
 {0x001B83, 0x0004},
 {0x001BA1, 0x0010},
@@ -764,25 +768,25 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x001BB0, 0x0002},
 {0x001BBA, 0x0004},
 {0x001BE6, 0x0010},
-{0x001BF4, 0x0080},
+{0x001BF4, 0x0001},
 {0x001BFC, 0x0020},
 {0x001C00, 0x0004},
 {0x001C24, 0x0010},
-{0x001C38, 0x0080},
+{0x001C38, 0x0001},
 {0x001C3B, 0x0020},
 {0x001C40, 0x0002},
-{0x001C4A, 0x0080},
+{0x001C4A, 0x0001},
 {0x001C4D, 0x0004},
 {0x001C50, 0x0002},
 {0x001C5A, 0x0004},
 {0x001C7E, 0x0020},
 {0x001C80, 0x0004},
-{0x001C89, 0x0080},
+{0x001C89, 0x0001},
 {0x001C90, 0x0004},
-{0x001CBB, 0x0080},
+{0x001CBB, 0x0001},
 {0x001CBD, 0x0004},
 {0x001CC0, 0x0020},
-{0x001CC8, 0x0080},
+{0x001CC8, 0x0001},
 {0x001CD0, 0x0010},
 {0x001CD3, 0x0020},
 {0x001CD4, 0x0010},
@@ -793,50 +797,50 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x001CF5, 0x0004},
 {0x001CF7, 0x0010},
 {0x001CFA, 0x0004},
-{0x001CFB, 0x0080},
+{0x001CFB, 0x0001},
 {0x001D00, 0x0004},
 {0x001DC0, 0x0010},
 {0x001E00, 0x0004},
-{0x001F16, 0x0080},
+{0x001F16, 0x0001},
 {0x001F18, 0x0004},
-{0x001F1E, 0x0080},
+{0x001F1E, 0x0001},
 {0x001F20, 0x0004},
-{0x001F46, 0x0080},
+{0x001F46, 0x0001},
 {0x001F48, 0x0004},
-{0x001F4E, 0x0080},
+{0x001F4E, 0x0001},
 {0x001F50, 0x0004},
-{0x001F58, 0x0080},
+{0x001F58, 0x0001},
 {0x001F59, 0x0004},
-{0x001F5A, 0x0080},
+{0x001F5A, 0x0001},
 {0x001F5B, 0x0004},
-{0x001F5C, 0x0080},
+{0x001F5C, 0x0001},
 {0x001F5D, 0x0004},
-{0x001F5E, 0x0080},
+{0x001F5E, 0x0001},
 {0x001F5F, 0x0004},
-{0x001F7E, 0x0080},
+{0x001F7E, 0x0001},
 {0x001F80, 0x0004},
-{0x001FB5, 0x0080},
+{0x001FB5, 0x0001},
 {0x001FB6, 0x0004},
 {0x001FBD, 0x0040},
 {0x001FBE, 0x0004},
 {0x001FBF, 0x0040},
 {0x001FC2, 0x0004},
-{0x001FC5, 0x0080},
+{0x001FC5, 0x0001},
 {0x001FC6, 0x0004},
 {0x001FCD, 0x0040},
 {0x001FD0, 0x0004},
-{0x001FD4, 0x0080},
+{0x001FD4, 0x0001},
 {0x001FD6, 0x0004},
-{0x001FDC, 0x0080},
+{0x001FDC, 0x0001},
 {0x001FDD, 0x0040},
 {0x001FE0, 0x0004},
 {0x001FED, 0x0040},
-{0x001FF0, 0x0080},
+{0x001FF0, 0x0001},
 {0x001FF2, 0x0004},
-{0x001FF5, 0x0080},
+{0x001FF5, 0x0001},
 {0x001FF6, 0x0004},
 {0x001FFD, 0x0040},
-{0x001FFF, 0x0080},
+{0x001FFF, 0x0001},
 {0x002000, 0x0008},
 {0x00200B, 0x0080},
 {0x002010, 0x0020},
@@ -850,9 +854,11 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x002053, 0x0020},
 {0x00205F, 0x0008},
 {0x002060, 0x0080},
+{0x002065, 0x0001},
+{0x002066, 0x0080},
 {0x002070, 0x0002},
 {0x002071, 0x0004},
-{0x002072, 0x0080},
+{0x002072, 0x0001},
 {0x002074, 0x0002},
 {0x00207A, 0x0040},
 {0x00207D, 0x0020},
@@ -860,13 +866,13 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x002080, 0x0002},
 {0x00208A, 0x0040},
 {0x00208D, 0x0020},
-{0x00208F, 0x0080},
+{0x00208F, 0x0001},
 {0x002090, 0x0004},
-{0x00209D, 0x0080},
+{0x00209D, 0x0001},
 {0x0020A0, 0x0040},
-{0x0020C1, 0x0080},
+{0x0020C1, 0x0001},
 {0x0020D0, 0x0010},
-{0x0020F1, 0x0080},
+{0x0020F1, 0x0001},
 {0x002100, 0x0040},
 {0x002102, 0x0004},
 {0x002103, 0x0040},
@@ -898,15 +904,15 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x002183, 0x0004},
 {0x002185, 0x0002},
 {0x00218A, 0x0040},
-{0x00218C, 0x0080},
+{0x00218C, 0x0001},
 {0x002190, 0x0040},
 {0x002308, 0x0020},
 {0x00230C, 0x0040},
 {0x002329, 0x0020},
 {0x00232B, 0x0040},
-{0x002427, 0x0080},
+{0x002427, 0x0001},
 {0x002440, 0x0040},
-{0x00244B, 0x0080},
+{0x00244B, 0x0001},
 {0x002460, 0x0002},
 {0x00249C, 0x0040},
 {0x0024EA, 0x0002},
@@ -924,62 +930,62 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0029DC, 0x0040},
 {0x0029FC, 0x0020},
 {0x0029FE, 0x0040},
-{0x002B74, 0x0080},
+{0x002B74, 0x0001},
 {0x002B76, 0x0040},
-{0x002B96, 0x0080},
+{0x002B96, 0x0001},
 {0x002B97, 0x0040},
 {0x002C00, 0x0004},
 {0x002CE5, 0x0040},
 {0x002CEB, 0x0004},
 {0x002CEF, 0x0010},
 {0x002CF2, 0x0004},
-{0x002CF4, 0x0080},
+{0x002CF4, 0x0001},
 {0x002CF9, 0x0020},
 {0x002CFD, 0x0002},
 {0x002CFE, 0x0020},
 {0x002D00, 0x0004},
-{0x002D26, 0x0080},
+{0x002D26, 0x0001},
 {0x002D27, 0x0004},
-{0x002D28, 0x0080},
+{0x002D28, 0x0001},
 {0x002D2D, 0x0004},
-{0x002D2E, 0x0080},
+{0x002D2E, 0x0001},
 {0x002D30, 0x0004},
-{0x002D68, 0x0080},
+{0x002D68, 0x0001},
 {0x002D6F, 0x0004},
 {0x002D70, 0x0020},
-{0x002D71, 0x0080},
+{0x002D71, 0x0001},
 {0x002D7F, 0x0010},
 {0x002D80, 0x0004},
-{0x002D97, 0x0080},
+{0x002D97, 0x0001},
 {0x002DA0, 0x0004},
-{0x002DA7, 0x0080},
+{0x002DA7, 0x0001},
 {0x002DA8, 0x0004},
-{0x002DAF, 0x0080},
+{0x002DAF, 0x0001},
 {0x002DB0, 0x0004},
-{0x002DB7, 0x0080},
+{0x002DB7, 0x0001},
 {0x002DB8, 0x0004},
-{0x002DBF, 0x0080},
+{0x002DBF, 0x0001},
 {0x002DC0, 0x0004},
-{0x002DC7, 0x0080},
+{0x002DC7, 0x0001},
 {0x002DC8, 0x0004},
-{0x002DCF, 0x0080},
+{0x002DCF, 0x0001},
 {0x002DD0, 0x0004},
-{0x002DD7, 0x0080},
+{0x002DD7, 0x0001},
 {0x002DD8, 0x0004},
-{0x002DDF, 0x0080},
+{0x002DDF, 0x0001},
 {0x002DE0, 0x0010},
 {0x002E00, 0x0020},
 {0x002E2F, 0x0004},
 {0x002E30, 0x0020},
 {0x002E50, 0x0040},
 {0x002E52, 0x0020},
-{0x002E5E, 0x0080},
+{0x002E5E, 0x0001},
 {0x002E80, 0x0040},
-{0x002E9A, 0x0080},
+{0x002E9A, 0x0001},
 {0x002E9B, 0x0040},
-{0x002EF4, 0x0080},
+{0x002EF4, 0x0001},
 {0x002F00, 0x0040},
-{0x002FD6, 0x0080},
+{0x002FD6, 0x0001},
 {0x002FF0, 0x0040},
 {0x003000, 0x0008},
 {0x003001, 0x0020},
@@ -999,9 +1005,9 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00303B, 0x0004},
 {0x00303D, 0x0020},
 {0x00303E, 0x0040},
-{0x003040, 0x0080},
+{0x003040, 0x0001},
 {0x003041, 0x0004},
-{0x003097, 0x0080},
+{0x003097, 0x0001},
 {0x003099, 0x0010},
 {0x00309B, 0x0040},
 {0x00309D, 0x0004},
@@ -1009,21 +1015,21 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0030A1, 0x0004},
 {0x0030FB, 0x0020},
 {0x0030FC, 0x0004},
-{0x003100, 0x0080},
+{0x003100, 0x0001},
 {0x003105, 0x0004},
-{0x003130, 0x0080},
+{0x003130, 0x0001},
 {0x003131, 0x0004},
-{0x00318F, 0x0080},
+{0x00318F, 0x0001},
 {0x003190, 0x0040},
 {0x003192, 0x0002},
 {0x003196, 0x0040},
 {0x0031A0, 0x0004},
 {0x0031C0, 0x0040},
-{0x0031E4, 0x0080},
+{0x0031E4, 0x0001},
 {0x0031EF, 0x0040},
 {0x0031F0, 0x0004},
 {0x003200, 0x0040},
-{0x00321F, 0x0080},
+{0x00321F, 0x0001},
 {0x003220, 0x0002},
 {0x00322A, 0x0040},
 {0x003248, 0x0002},
@@ -1037,9 +1043,9 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x003400, 0x0004},
 {0x004DC0, 0x0040},
 {0x004E00, 0x0004},
-{0x00A48D, 0x0080},
+{0x00A48D, 0x0001},
 {0x00A490, 0x0040},
-{0x00A4C7, 0x0080},
+{0x00A4C7, 0x0001},
 {0x00A4D0, 0x0004},
 {0x00A4FE, 0x0020},
 {0x00A500, 0x0004},
@@ -1047,7 +1053,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A610, 0x0004},
 {0x00A620, 0x0002},
 {0x00A62A, 0x0004},
-{0x00A62C, 0x0080},
+{0x00A62C, 0x0001},
 {0x00A640, 0x0004},
 {0x00A66F, 0x0010},
 {0x00A673, 0x0020},
@@ -1059,20 +1065,20 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A6E6, 0x0002},
 {0x00A6F0, 0x0010},
 {0x00A6F2, 0x0020},
-{0x00A6F8, 0x0080},
+{0x00A6F8, 0x0001},
 {0x00A700, 0x0040},
 {0x00A717, 0x0004},
 {0x00A720, 0x0040},
 {0x00A722, 0x0004},
 {0x00A789, 0x0040},
 {0x00A78B, 0x0004},
-{0x00A7CB, 0x0080},
+{0x00A7CB, 0x0001},
 {0x00A7D0, 0x0004},
-{0x00A7D2, 0x0080},
+{0x00A7D2, 0x0001},
 {0x00A7D3, 0x0004},
-{0x00A7D4, 0x0080},
+{0x00A7D4, 0x0001},
 {0x00A7D5, 0x0004},
-{0x00A7DA, 0x0080},
+{0x00A7DA, 0x0001},
 {0x00A7F2, 0x0004},
 {0x00A802, 0x0010},
 {0x00A803, 0x0004},
@@ -1083,20 +1089,20 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A823, 0x0010},
 {0x00A828, 0x0040},
 {0x00A82C, 0x0010},
-{0x00A82D, 0x0080},
+{0x00A82D, 0x0001},
 {0x00A830, 0x0002},
 {0x00A836, 0x0040},
-{0x00A83A, 0x0080},
+{0x00A83A, 0x0001},
 {0x00A840, 0x0004},
 {0x00A874, 0x0020},
-{0x00A878, 0x0080},
+{0x00A878, 0x0001},
 {0x00A880, 0x0010},
 {0x00A882, 0x0004},
 {0x00A8B4, 0x0010},
-{0x00A8C6, 0x0080},
+{0x00A8C6, 0x0001},
 {0x00A8CE, 0x0020},
 {0x00A8D0, 0x0002},
-{0x00A8DA, 0x0080},
+{0x00A8DA, 0x0001},
 {0x00A8E0, 0x0010},
 {0x00A8F2, 0x0004},
 {0x00A8F8, 0x0020},
@@ -1110,35 +1116,35 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A92E, 0x0020},
 {0x00A930, 0x0004},
 {0x00A947, 0x0010},
-{0x00A954, 0x0080},
+{0x00A954, 0x0001},
 {0x00A95F, 0x0020},
 {0x00A960, 0x0004},
-{0x00A97D, 0x0080},
+{0x00A97D, 0x0001},
 {0x00A980, 0x0010},
 {0x00A984, 0x0004},
 {0x00A9B3, 0x0010},
 {0x00A9C1, 0x0020},
-{0x00A9CE, 0x0080},
+{0x00A9CE, 0x0001},
 {0x00A9CF, 0x0004},
 {0x00A9D0, 0x0002},
-{0x00A9DA, 0x0080},
+{0x00A9DA, 0x0001},
 {0x00A9DE, 0x0020},
 {0x00A9E0, 0x0004},
 {0x00A9E5, 0x0010},
 {0x00A9E6, 0x0004},
 {0x00A9F0, 0x0002},
 {0x00A9FA, 0x0004},
-{0x00A9FF, 0x0080},
+{0x00A9FF, 0x0001},
 {0x00AA00, 0x0004},
 {0x00AA29, 0x0010},
-{0x00AA37, 0x0080},
+{0x00AA37, 0x0001},
 {0x00AA40, 0x0004},
 {0x00AA43, 0x0010},
 {0x00AA44, 0x0004},
 {0x00AA4C, 0x0010},
-{0x00AA4E, 0x0080},
+{0x00AA4E, 0x0001},
 {0x00AA50, 0x0002},
-{0x00AA5A, 0x0080},
+{0x00AA5A, 0x0001},
 {0x00AA5C, 0x0020},
 {0x00AA60, 0x0004},
 {0x00AA77, 0x0040},
@@ -1155,7 +1161,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00AAC0, 0x0004},
 {0x00AAC1, 0x0010},
 {0x00AAC2, 0x0004},
-{0x00AAC3, 0x0080},
+{0x00AAC3, 0x0001},
 {0x00AADB, 0x0004},
 {0x00AADE, 0x0020},
 {0x00AAE0, 0x0004},
@@ -1163,90 +1169,93 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00AAF0, 0x0020},
 {0x00AAF2, 0x0004},
 {0x00AAF5, 0x0010},
-{0x00AAF7, 0x0080},
+{0x00AAF7, 0x0001},
 {0x00AB01, 0x0004},
-{0x00AB07, 0x0080},
+{0x00AB07, 0x0001},
 {0x00AB09, 0x0004},
-{0x00AB0F, 0x0080},
+{0x00AB0F, 0x0001},
 {0x00AB11, 0x0004},
-{0x00AB17, 0x0080},
+{0x00AB17, 0x0001},
 {0x00AB20, 0x0004},
-{0x00AB27, 0x0080},
+{0x00AB27, 0x0001},
 {0x00AB28, 0x0004},
-{0x00AB2F, 0x0080},
+{0x00AB2F, 0x0001},
 {0x00AB30, 0x0004},
 {0x00AB5B, 0x0040},
 {0x00AB5C, 0x0004},
 {0x00AB6A, 0x0040},
-{0x00AB6C, 0x0080},
+{0x00AB6C, 0x0001},
 {0x00AB70, 0x0004},
 {0x00ABE3, 0x0010},
 {0x00ABEB, 0x0020},
 {0x00ABEC, 0x0010},
-{0x00ABEE, 0x0080},
+{0x00ABEE, 0x0001},
 {0x00ABF0, 0x0002},
-{0x00ABFA, 0x0080},
+{0x00ABFA, 0x0001},
 {0x00AC00, 0x0004},
-{0x00D7A4, 0x0080},
+{0x00D7A4, 0x0001},
 {0x00D7B0, 0x0004},
-{0x00D7C7, 0x0080},
+{0x00D7C7, 0x0001},
 {0x00D7CB, 0x0004},
-{0x00D7FC, 0x0080},
+{0x00D7FC, 0x0001},
+{0x00D800, 0x0080},
 {0x00F900, 0x0004},
-{0x00FA6E, 0x0080},
+{0x00FA6E, 0x0001},
 {0x00FA70, 0x0004},
-{0x00FADA, 0x0080},
+{0x00FADA, 0x0001},
 {0x00FB00, 0x0004},
-{0x00FB07, 0x0080},
+{0x00FB07, 0x0001},
 {0x00FB13, 0x0004},
-{0x00FB18, 0x0080},
+{0x00FB18, 0x0001},
 {0x00FB1D, 0x0004},
 {0x00FB1E, 0x0010},
 {0x00FB1F, 0x0004},
 {0x00FB29, 0x0040},
 {0x00FB2A, 0x0004},
-{0x00FB37, 0x0080},
+{0x00FB37, 0x0001},
 {0x00FB38, 0x0004},
-{0x00FB3D, 0x0080},
+{0x00FB3D, 0x0001},
 {0x00FB3E, 0x0004},
-{0x00FB3F, 0x0080},
+{0x00FB3F, 0x0001},
 {0x00FB40, 0x0004},
-{0x00FB42, 0x0080},
+{0x00FB42, 0x0001},
 {0x00FB43, 0x0004},
-{0x00FB45, 0x0080},
+{0x00FB45, 0x0001},
 {0x00FB46, 0x0004},
 {0x00FBB2, 0x0040},
-{0x00FBC3, 0x0080},
+{0x00FBC3, 0x0001},
 {0x00FBD3, 0x0004},
 {0x00FD3E, 0x0020},
 {0x00FD40, 0x0040},
 {0x00FD50, 0x0004},
-{0x00FD90, 0x0080},
+{0x00FD90, 0x0001},
 {0x00FD92, 0x0004},
-{0x00FDC8, 0x0080},
+{0x00FDC8, 0x0001},
 {0x00FDCF, 0x0040},
-{0x00FDD0, 0x0080},
+{0x00FDD0, 0x0001},
 {0x00FDF0, 0x0004},
 {0x00FDFC, 0x0040},
 {0x00FE00, 0x0010},
 {0x00FE10, 0x0020},
-{0x00FE1A, 0x0080},
+{0x00FE1A, 0x0001},
 {0x00FE20, 0x0010},
 {0x00FE30, 0x0020},
-{0x00FE53, 0x0080},
+{0x00FE53, 0x0001},
 {0x00FE54, 0x0020},
 {0x00FE62, 0x0040},
 {0x00FE63, 0x0020},
 {0x00FE64, 0x0040},
-{0x00FE67, 0x0080},
+{0x00FE67, 0x0001},
 {0x00FE68, 0x0020},
 {0x00FE69, 0x0040},
 {0x00FE6A, 0x0020},
-{0x00FE6C, 0x0080},
+{0x00FE6C, 0x0001},
 {0x00FE70, 0x0004},
-{0x00FE75, 0x0080},
+{0x00FE75, 0x0001},
 {0x00FE76, 0x0004},
-{0x00FEFD, 0x0080},
+{0x00FEFD, 0x0001},
+{0x00FEFF, 0x0080},
+{0x00FF00, 0x0001},
 {0x00FF01, 0x0020},
 {0x00FF04, 0x0040},
 {0x00FF05, 0x0020},
@@ -1268,260 +1277,261 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00FF5E, 0x0040},
 {0x00FF5F, 0x0020},
 {0x00FF66, 0x0004},
-{0x00FFBF, 0x0080},
+{0x00FFBF, 0x0001},
 {0x00FFC2, 0x0004},
-{0x00FFC8, 0x0080},
+{0x00FFC8, 0x0001},
 {0x00FFCA, 0x0004},
-{0x00FFD0, 0x0080},
+{0x00FFD0, 0x0001},
 {0x00FFD2, 0x0004},
-{0x00FFD8, 0x0080},
+{0x00FFD8, 0x0001},
 {0x00FFDA, 0x0004},
-{0x00FFDD, 0x0080},
+{0x00FFDD, 0x0001},
 {0x00FFE0, 0x0040},
-{0x00FFE7, 0x0080},
+{0x00FFE7, 0x0001},
 {0x00FFE8, 0x0040},
-{0x00FFEF, 0x0080},
+{0x00FFEF, 0x0001},
+{0x00FFF9, 0x0080},
 {0x00FFFC, 0x0040},
-{0x00FFFE, 0x0080},
+{0x00FFFE, 0x0001},
 {0x010000, 0x0004},
-{0x01000C, 0x0080},
+{0x01000C, 0x0001},
 {0x01000D, 0x0004},
-{0x010027, 0x0080},
+{0x010027, 0x0001},
 {0x010028, 0x0004},
-{0x01003B, 0x0080},
+{0x01003B, 0x0001},
 {0x01003C, 0x0004},
-{0x01003E, 0x0080},
+{0x01003E, 0x0001},
 {0x01003F, 0x0004},
-{0x01004E, 0x0080},
+{0x01004E, 0x0001},
 {0x010050, 0x0004},
-{0x01005E, 0x0080},
+{0x01005E, 0x0001},
 {0x010080, 0x0004},
-{0x0100FB, 0x0080},
+{0x0100FB, 0x0001},
 {0x010100, 0x0020},
-{0x010103, 0x0080},
+{0x010103, 0x0001},
 {0x010107, 0x0002},
-{0x010134, 0x0080},
+{0x010134, 0x0001},
 {0x010137, 0x0040},
 {0x010140, 0x0002},
 {0x010179, 0x0040},
 {0x01018A, 0x0002},
 {0x01018C, 0x0040},
-{0x01018F, 0x0080},
+{0x01018F, 0x0001},
 {0x010190, 0x0040},
-{0x01019D, 0x0080},
+{0x01019D, 0x0001},
 {0x0101A0, 0x0040},
-{0x0101A1, 0x0080},
+{0x0101A1, 0x0001},
 {0x0101D0, 0x0040},
 {0x0101FD, 0x0010},
-{0x0101FE, 0x0080},
+{0x0101FE, 0x0001},
 {0x010280, 0x0004},
-{0x01029D, 0x0080},
+{0x01029D, 0x0001},
 {0x0102A0, 0x0004},
-{0x0102D1, 0x0080},
+{0x0102D1, 0x0001},
 {0x0102E0, 0x0010},
 {0x0102E1, 0x0002},
-{0x0102FC, 0x0080},
+{0x0102FC, 0x0001},
 {0x010300, 0x0004},
 {0x010320, 0x0002},
-{0x010324, 0x0080},
+{0x010324, 0x0001},
 {0x01032D, 0x0004},
 {0x010341, 0x0002},
 {0x010342, 0x0004},
 {0x01034A, 0x0002},
-{0x01034B, 0x0080},
+{0x01034B, 0x0001},
 {0x010350, 0x0004},
 {0x010376, 0x0010},
-{0x01037B, 0x0080},
+{0x01037B, 0x0001},
 {0x010380, 0x0004},
-{0x01039E, 0x0080},
+{0x01039E, 0x0001},
 {0x01039F, 0x0020},
 {0x0103A0, 0x0004},
-{0x0103C4, 0x0080},
+{0x0103C4, 0x0001},
 {0x0103C8, 0x0004},
 {0x0103D0, 0x0020},
 {0x0103D1, 0x0002},
-{0x0103D6, 0x0080},
+{0x0103D6, 0x0001},
 {0x010400, 0x0004},
-{0x01049E, 0x0080},
+{0x01049E, 0x0001},
 {0x0104A0, 0x0002},
-{0x0104AA, 0x0080},
+{0x0104AA, 0x0001},
 {0x0104B0, 0x0004},
-{0x0104D4, 0x0080},
+{0x0104D4, 0x0001},
 {0x0104D8, 0x0004},
-{0x0104FC, 0x0080},
+{0x0104FC, 0x0001},
 {0x010500, 0x0004},
-{0x010528, 0x0080},
+{0x010528, 0x0001},
 {0x010530, 0x0004},
-{0x010564, 0x0080},
+{0x010564, 0x0001},
 {0x01056F, 0x0020},
 {0x010570, 0x0004},
-{0x01057B, 0x0080},
+{0x01057B, 0x0001},
 {0x01057C, 0x0004},
-{0x01058B, 0x0080},
+{0x01058B, 0x0001},
 {0x01058C, 0x0004},
-{0x010593, 0x0080},
+{0x010593, 0x0001},
 {0x010594, 0x0004},
-{0x010596, 0x0080},
+{0x010596, 0x0001},
 {0x010597, 0x0004},
-{0x0105A2, 0x0080},
+{0x0105A2, 0x0001},
 {0x0105A3, 0x0004},
-{0x0105B2, 0x0080},
+{0x0105B2, 0x0001},
 {0x0105B3, 0x0004},
-{0x0105BA, 0x0080},
+{0x0105BA, 0x0001},
 {0x0105BB, 0x0004},
-{0x0105BD, 0x0080},
+{0x0105BD, 0x0001},
 {0x010600, 0x0004},
-{0x010737, 0x0080},
+{0x010737, 0x0001},
 {0x010740, 0x0004},
-{0x010756, 0x0080},
+{0x010756, 0x0001},
 {0x010760, 0x0004},
-{0x010768, 0x0080},
+{0x010768, 0x0001},
 {0x010780, 0x0004},
-{0x010786, 0x0080},
+{0x010786, 0x0001},
 {0x010787, 0x0004},
-{0x0107B1, 0x0080},
+{0x0107B1, 0x0001},
 {0x0107B2, 0x0004},
-{0x0107BB, 0x0080},
+{0x0107BB, 0x0001},
 {0x010800, 0x0004},
-{0x010806, 0x0080},
+{0x010806, 0x0001},
 {0x010808, 0x0004},
-{0x010809, 0x0080},
+{0x010809, 0x0001},
 {0x01080A, 0x0004},
-{0x010836, 0x0080},
+{0x010836, 0x0001},
 {0x010837, 0x0004},
-{0x010839, 0x0080},
+{0x010839, 0x0001},
 {0x01083C, 0x0004},
-{0x01083D, 0x0080},
+{0x01083D, 0x0001},
 {0x01083F, 0x0004},
-{0x010856, 0x0080},
+{0x010856, 0x0001},
 {0x010857, 0x0020},
 {0x010858, 0x0002},
 {0x010860, 0x0004},
 {0x010877, 0x0040},
 {0x010879, 0x0002},
 {0x010880, 0x0004},
-{0x01089F, 0x0080},
+{0x01089F, 0x0001},
 {0x0108A7, 0x0002},
-{0x0108B0, 0x0080},
+{0x0108B0, 0x0001},
 {0x0108E0, 0x0004},
-{0x0108F3, 0x0080},
+{0x0108F3, 0x0001},
 {0x0108F4, 0x0004},
-{0x0108F6, 0x0080},
+{0x0108F6, 0x0001},
 {0x0108FB, 0x0002},
 {0x010900, 0x0004},
 {0x010916, 0x0002},
-{0x01091C, 0x0080},
+{0x01091C, 0x0001},
 {0x01091F, 0x0020},
 {0x010920, 0x0004},
-{0x01093A, 0x0080},
+{0x01093A, 0x0001},
 {0x01093F, 0x0020},
-{0x010940, 0x0080},
+{0x010940, 0x0001},
 {0x010980, 0x0004},
-{0x0109B8, 0x0080},
+{0x0109B8, 0x0001},
 {0x0109BC, 0x0002},
 {0x0109BE, 0x0004},
 {0x0109C0, 0x0002},
-{0x0109D0, 0x0080},
+{0x0109D0, 0x0001},
 {0x0109D2, 0x0002},
 {0x010A00, 0x0004},
 {0x010A01, 0x0010},
-{0x010A04, 0x0080},
+{0x010A04, 0x0001},
 {0x010A05, 0x0010},
-{0x010A07, 0x0080},
+{0x010A07, 0x0001},
 {0x010A0C, 0x0010},
 {0x010A10, 0x0004},
-{0x010A14, 0x0080},
+{0x010A14, 0x0001},
 {0x010A15, 0x0004},
-{0x010A18, 0x0080},
+{0x010A18, 0x0001},
 {0x010A19, 0x0004},
-{0x010A36, 0x0080},
+{0x010A36, 0x0001},
 {0x010A38, 0x0010},
-{0x010A3B, 0x0080},
+{0x010A3B, 0x0001},
 {0x010A3F, 0x0010},
 {0x010A40, 0x0002},
-{0x010A49, 0x0080},
+{0x010A49, 0x0001},
 {0x010A50, 0x0020},
-{0x010A59, 0x0080},
+{0x010A59, 0x0001},
 {0x010A60, 0x0004},
 {0x010A7D, 0x0002},
 {0x010A7F, 0x0020},
 {0x010A80, 0x0004},
 {0x010A9D, 0x0002},
-{0x010AA0, 0x0080},
+{0x010AA0, 0x0001},
 {0x010AC0, 0x0004},
 {0x010AC8, 0x0040},
 {0x010AC9, 0x0004},
 {0x010AE5, 0x0010},
-{0x010AE7, 0x0080},
+{0x010AE7, 0x0001},
 {0x010AEB, 0x0002},
 {0x010AF0, 0x0020},
-{0x010AF7, 0x0080},
+{0x010AF7, 0x0001},
 {0x010B00, 0x0004},
-{0x010B36, 0x0080},
+{0x010B36, 0x0001},
 {0x010B39, 0x0020},
 {0x010B40, 0x0004},
-{0x010B56, 0x0080},
+{0x010B56, 0x0001},
 {0x010B58, 0x0002},
 {0x010B60, 0x0004},
-{0x010B73, 0x0080},
+{0x010B73, 0x0001},
 {0x010B78, 0x0002},
 {0x010B80, 0x0004},
-{0x010B92, 0x0080},
+{0x010B92, 0x0001},
 {0x010B99, 0x0020},
-{0x010B9D, 0x0080},
+{0x010B9D, 0x0001},
 {0x010BA9, 0x0002},
-{0x010BB0, 0x0080},
+{0x010BB0, 0x0001},
 {0x010C00, 0x0004},
-{0x010C49, 0x0080},
+{0x010C49, 0x0001},
 {0x010C80, 0x0004},
-{0x010CB3, 0x0080},
+{0x010CB3, 0x0001},
 {0x010CC0, 0x0004},
-{0x010CF3, 0x0080},
+{0x010CF3, 0x0001},
 {0x010CFA, 0x0002},
 {0x010D00, 0x0004},
 {0x010D24, 0x0010},
-{0x010D28, 0x0080},
+{0x010D28, 0x0001},
 {0x010D30, 0x0002},
-{0x010D3A, 0x0080},
+{0x010D3A, 0x0001},
 {0x010E60, 0x0002},
-{0x010E7F, 0x0080},
+{0x010E7F, 0x0001},
 {0x010E80, 0x0004},
-{0x010EAA, 0x0080},
+{0x010EAA, 0x0001},
 {0x010EAB, 0x0010},
 {0x010EAD, 0x0020},
-{0x010EAE, 0x0080},
+{0x010EAE, 0x0001},
 {0x010EB0, 0x0004},
-{0x010EB2, 0x0080},
+{0x010EB2, 0x0001},
 {0x010EFD, 0x0010},
 {0x010F00, 0x0004},
 {0x010F1D, 0x0002},
 {0x010F27, 0x0004},
-{0x010F28, 0x0080},
+{0x010F28, 0x0001},
 {0x010F30, 0x0004},
 {0x010F46, 0x0010},
 {0x010F51, 0x0002},
 {0x010F55, 0x0020},
-{0x010F5A, 0x0080},
+{0x010F5A, 0x0001},
 {0x010F70, 0x0004},
 {0x010F82, 0x0010},
 {0x010F86, 0x0020},
-{0x010F8A, 0x0080},
+{0x010F8A, 0x0001},
 {0x010FB0, 0x0004},
 {0x010FC5, 0x0002},
-{0x010FCC, 0x0080},
+{0x010FCC, 0x0001},
 {0x010FE0, 0x0004},
-{0x010FF7, 0x0080},
+{0x010FF7, 0x0001},
 {0x011000, 0x0010},
 {0x011003, 0x0004},
 {0x011038, 0x0010},
 {0x011047, 0x0020},
-{0x01104E, 0x0080},
+{0x01104E, 0x0001},
 {0x011052, 0x0002},
 {0x011070, 0x0010},
 {0x011071, 0x0004},
 {0x011073, 0x0010},
 {0x011075, 0x0004},
-{0x011076, 0x0080},
+{0x011076, 0x0001},
 {0x01107F, 0x0010},
 {0x011083, 0x0004},
 {0x0110B0, 0x0010},
@@ -1529,26 +1539,28 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0110BD, 0x0080},
 {0x0110BE, 0x0020},
 {0x0110C2, 0x0010},
-{0x0110C3, 0x0080},
+{0x0110C3, 0x0001},
+{0x0110CD, 0x0080},
+{0x0110CE, 0x0001},
 {0x0110D0, 0x0004},
-{0x0110E9, 0x0080},
+{0x0110E9, 0x0001},
 {0x0110F0, 0x0002},
-{0x0110FA, 0x0080},
+{0x0110FA, 0x0001},
 {0x011100, 0x0010},
 {0x011103, 0x0004},
 {0x011127, 0x0010},
-{0x011135, 0x0080},
+{0x011135, 0x0001},
 {0x011136, 0x0002},
 {0x011140, 0x0020},
 {0x011144, 0x0004},
 {0x011145, 0x0010},
 {0x011147, 0x0004},
-{0x011148, 0x0080},
+{0x011148, 0x0001},
 {0x011150, 0x0004},
 {0x011173, 0x0010},
 {0x011174, 0x0020},
 {0x011176, 0x0004},
-{0x011177, 0x0080},
+{0x011177, 0x0001},
 {0x011180, 0x0010},
 {0x011183, 0x0004},
 {0x0111B3, 0x0010},
@@ -1562,159 +1574,159 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0111DB, 0x0020},
 {0x0111DC, 0x0004},
 {0x0111DD, 0x0020},
-{0x0111E0, 0x0080},
+{0x0111E0, 0x0001},
 {0x0111E1, 0x0002},
-{0x0111F5, 0x0080},
+{0x0111F5, 0x0001},
 {0x011200, 0x0004},
-{0x011212, 0x0080},
+{0x011212, 0x0001},
 {0x011213, 0x0004},
 {0x01122C, 0x0010},
 {0x011238, 0x0020},
 {0x01123E, 0x0010},
 {0x01123F, 0x0004},
 {0x011241, 0x0010},
-{0x011242, 0x0080},
+{0x011242, 0x0001},
 {0x011280, 0x0004},
-{0x011287, 0x0080},
+{0x011287, 0x0001},
 {0x011288, 0x0004},
-{0x011289, 0x0080},
+{0x011289, 0x0001},
 {0x01128A, 0x0004},
-{0x01128E, 0x0080},
+{0x01128E, 0x0001},
 {0x01128F, 0x0004},
-{0x01129E, 0x0080},
+{0x01129E, 0x0001},
 {0x01129F, 0x0004},
 {0x0112A9, 0x0020},
-{0x0112AA, 0x0080},
+{0x0112AA, 0x0001},
 {0x0112B0, 0x0004},
 {0x0112DF, 0x0010},
-{0x0112EB, 0x0080},
+{0x0112EB, 0x0001},
 {0x0112F0, 0x0002},
-{0x0112FA, 0x0080},
+{0x0112FA, 0x0001},
 {0x011300, 0x0010},
-{0x011304, 0x0080},
+{0x011304, 0x0001},
 {0x011305, 0x0004},
-{0x01130D, 0x0080},
+{0x01130D, 0x0001},
 {0x01130F, 0x0004},
-{0x011311, 0x0080},
+{0x011311, 0x0001},
 {0x011313, 0x0004},
-{0x011329, 0x0080},
+{0x011329, 0x0001},
 {0x01132A, 0x0004},
-{0x011331, 0x0080},
+{0x011331, 0x0001},
 {0x011332, 0x0004},
-{0x011334, 0x0080},
+{0x011334, 0x0001},
 {0x011335, 0x0004},
-{0x01133A, 0x0080},
+{0x01133A, 0x0001},
 {0x01133B, 0x0010},
 {0x01133D, 0x0004},
 {0x01133E, 0x0010},
-{0x011345, 0x0080},
+{0x011345, 0x0001},
 {0x011347, 0x0010},
-{0x011349, 0x0080},
+{0x011349, 0x0001},
 {0x01134B, 0x0010},
-{0x01134E, 0x0080},
+{0x01134E, 0x0001},
 {0x011350, 0x0004},
-{0x011351, 0x0080},
+{0x011351, 0x0001},
 {0x011357, 0x0010},
-{0x011358, 0x0080},
+{0x011358, 0x0001},
 {0x01135D, 0x0004},
 {0x011362, 0x0010},
-{0x011364, 0x0080},
+{0x011364, 0x0001},
 {0x011366, 0x0010},
-{0x01136D, 0x0080},
+{0x01136D, 0x0001},
 {0x011370, 0x0010},
-{0x011375, 0x0080},
+{0x011375, 0x0001},
 {0x011400, 0x0004},
 {0x011435, 0x0010},
 {0x011447, 0x0004},
 {0x01144B, 0x0020},
 {0x011450, 0x0002},
 {0x01145A, 0x0020},
-{0x01145C, 0x0080},
+{0x01145C, 0x0001},
 {0x01145D, 0x0020},
 {0x01145E, 0x0010},
 {0x01145F, 0x0004},
-{0x011462, 0x0080},
+{0x011462, 0x0001},
 {0x011480, 0x0004},
 {0x0114B0, 0x0010},
 {0x0114C4, 0x0004},
 {0x0114C6, 0x0020},
 {0x0114C7, 0x0004},
-{0x0114C8, 0x0080},
+{0x0114C8, 0x0001},
 {0x0114D0, 0x0002},
-{0x0114DA, 0x0080},
+{0x0114DA, 0x0001},
 {0x011580, 0x0004},
 {0x0115AF, 0x0010},
-{0x0115B6, 0x0080},
+{0x0115B6, 0x0001},
 {0x0115B8, 0x0010},
 {0x0115C1, 0x0020},
 {0x0115D8, 0x0004},
 {0x0115DC, 0x0010},
-{0x0115DE, 0x0080},
+{0x0115DE, 0x0001},
 {0x011600, 0x0004},
 {0x011630, 0x0010},
 {0x011641, 0x0020},
 {0x011644, 0x0004},
-{0x011645, 0x0080},
+{0x011645, 0x0001},
 {0x011650, 0x0002},
-{0x01165A, 0x0080},
+{0x01165A, 0x0001},
 {0x011660, 0x0020},
-{0x01166D, 0x0080},
+{0x01166D, 0x0001},
 {0x011680, 0x0004},
 {0x0116AB, 0x0010},
 {0x0116B8, 0x0004},
 {0x0116B9, 0x0020},
-{0x0116BA, 0x0080},
+{0x0116BA, 0x0001},
 {0x0116C0, 0x0002},
-{0x0116CA, 0x0080},
+{0x0116CA, 0x0001},
 {0x011700, 0x0004},
-{0x01171B, 0x0080},
+{0x01171B, 0x0001},
 {0x01171D, 0x0010},
-{0x01172C, 0x0080},
+{0x01172C, 0x0001},
 {0x011730, 0x0002},
 {0x01173C, 0x0020},
 {0x01173F, 0x0040},
 {0x011740, 0x0004},
-{0x011747, 0x0080},
+{0x011747, 0x0001},
 {0x011800, 0x0004},
 {0x01182C, 0x0010},
 {0x01183B, 0x0020},
-{0x01183C, 0x0080},
+{0x01183C, 0x0001},
 {0x0118A0, 0x0004},
 {0x0118E0, 0x0002},
-{0x0118F3, 0x0080},
+{0x0118F3, 0x0001},
 {0x0118FF, 0x0004},
-{0x011907, 0x0080},
+{0x011907, 0x0001},
 {0x011909, 0x0004},
-{0x01190A, 0x0080},
+{0x01190A, 0x0001},
 {0x01190C, 0x0004},
-{0x011914, 0x0080},
+{0x011914, 0x0001},
 {0x011915, 0x0004},
-{0x011917, 0x0080},
+{0x011917, 0x0001},
 {0x011918, 0x0004},
 {0x011930, 0x0010},
-{0x011936, 0x0080},
+{0x011936, 0x0001},
 {0x011937, 0x0010},
-{0x011939, 0x0080},
+{0x011939, 0x0001},
 {0x01193B, 0x0010},
 {0x01193F, 0x0004},
 {0x011940, 0x0010},
 {0x011941, 0x0004},
 {0x011942, 0x0010},
 {0x011944, 0x0020},
-{0x011947, 0x0080},
+{0x011947, 0x0001},
 {0x011950, 0x0002},
-{0x01195A, 0x0080},
+{0x01195A, 0x0001},
 {0x0119A0, 0x0004},
-{0x0119A8, 0x0080},
+{0x0119A8, 0x0001},
 {0x0119AA, 0x0004},
 {0x0119D1, 0x0010},
-{0x0119D8, 0x0080},
+{0x0119D8, 0x0001},
 {0x0119DA, 0x0010},
 {0x0119E1, 0x0004},
 {0x0119E2, 0x0020},
 {0x0119E3, 0x0004},
 {0x0119E4, 0x0010},
-{0x0119E5, 0x0080},
+{0x0119E5, 0x0001},
 {0x011A00, 0x0004},
 {0x011A01, 0x0010},
 {0x011A0B, 0x0004},
@@ -1723,7 +1735,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x011A3B, 0x0010},
 {0x011A3F, 0x0020},
 {0x011A47, 0x0010},
-{0x011A48, 0x0080},
+{0x011A48, 0x0001},
 {0x011A50, 0x0004},
 {0x011A51, 0x0010},
 {0x011A5C, 0x0004},
@@ -1731,117 +1743,117 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x011A9A, 0x0020},
 {0x011A9D, 0x0004},
 {0x011A9E, 0x0020},
-{0x011AA3, 0x0080},
+{0x011AA3, 0x0001},
 {0x011AB0, 0x0004},
-{0x011AF9, 0x0080},
+{0x011AF9, 0x0001},
 {0x011B00, 0x0020},
-{0x011B0A, 0x0080},
+{0x011B0A, 0x0001},
 {0x011C00, 0x0004},
-{0x011C09, 0x0080},
+{0x011C09, 0x0001},
 {0x011C0A, 0x0004},
 {0x011C2F, 0x0010},
-{0x011C37, 0x0080},
+{0x011C37, 0x0001},
 {0x011C38, 0x0010},
 {0x011C40, 0x0004},
 {0x011C41, 0x0020},
-{0x011C46, 0x0080},
+{0x011C46, 0x0001},
 {0x011C50, 0x0002},
-{0x011C6D, 0x0080},
+{0x011C6D, 0x0001},
 {0x011C70, 0x0020},
 {0x011C72, 0x0004},
-{0x011C90, 0x0080},
+{0x011C90, 0x0001},
 {0x011C92, 0x0010},
-{0x011CA8, 0x0080},
+{0x011CA8, 0x0001},
 {0x011CA9, 0x0010},
-{0x011CB7, 0x0080},
+{0x011CB7, 0x0001},
 {0x011D00, 0x0004},
-{0x011D07, 0x0080},
+{0x011D07, 0x0001},
 {0x011D08, 0x0004},
-{0x011D0A, 0x0080},
+{0x011D0A, 0x0001},
 {0x011D0B, 0x0004},
 {0x011D31, 0x0010},
-{0x011D37, 0x0080},
+{0x011D37, 0x0001},
 {0x011D3A, 0x0010},
-{0x011D3B, 0x0080},
+{0x011D3B, 0x0001},
 {0x011D3C, 0x0010},
-{0x011D3E, 0x0080},
+{0x011D3E, 0x0001},
 {0x011D3F, 0x0010},
 {0x011D46, 0x0004},
 {0x011D47, 0x0010},
-{0x011D48, 0x0080},
+{0x011D48, 0x0001},
 {0x011D50, 0x0002},
-{0x011D5A, 0x0080},
+{0x011D5A, 0x0001},
 {0x011D60, 0x0004},
-{0x011D66, 0x0080},
+{0x011D66, 0x0001},
 {0x011D67, 0x0004},
-{0x011D69, 0x0080},
+{0x011D69, 0x0001},
 {0x011D6A, 0x0004},
 {0x011D8A, 0x0010},
-{0x011D8F, 0x0080},
+{0x011D8F, 0x0001},
 {0x011D90, 0x0010},
-{0x011D92, 0x0080},
+{0x011D92, 0x0001},
 {0x011D93, 0x0010},
 {0x011D98, 0x0004},
-{0x011D99, 0x0080},
+{0x011D99, 0x0001},
 {0x011DA0, 0x0002},
-{0x011DAA, 0x0080},
+{0x011DAA, 0x0001},
 {0x011EE0, 0x0004},
 {0x011EF3, 0x0010},
 {0x011EF7, 0x0020},
-{0x011EF9, 0x0080},
+{0x011EF9, 0x0001},
 {0x011F00, 0x0010},
 {0x011F02, 0x0004},
 {0x011F03, 0x0010},
 {0x011F04, 0x0004},
-{0x011F11, 0x0080},
+{0x011F11, 0x0001},
 {0x011F12, 0x0004},
 {0x011F34, 0x0010},
-{0x011F3B, 0x0080},
+{0x011F3B, 0x0001},
 {0x011F3E, 0x0010},
 {0x011F43, 0x0020},
 {0x011F50, 0x0002},
-{0x011F5A, 0x0080},
+{0x011F5A, 0x0001},
 {0x011FB0, 0x0004},
-{0x011FB1, 0x0080},
+{0x011FB1, 0x0001},
 {0x011FC0, 0x0002},
 {0x011FD5, 0x0040},
-{0x011FF2, 0x0080},
+{0x011FF2, 0x0001},
 {0x011FFF, 0x0020},
 {0x012000, 0x0004},
-{0x01239A, 0x0080},
+{0x01239A, 0x0001},
 {0x012400, 0x0002},
-{0x01246F, 0x0080},
+{0x01246F, 0x0001},
 {0x012470, 0x0020},
-{0x012475, 0x0080},
+{0x012475, 0x0001},
 {0x012480, 0x0004},
-{0x012544, 0x0080},
+{0x012544, 0x0001},
 {0x012F90, 0x0004},
 {0x012FF1, 0x0020},
-{0x012FF3, 0x0080},
+{0x012FF3, 0x0001},
 {0x013000, 0x0004},
 {0x013430, 0x0080},
 {0x013440, 0x0010},
 {0x013441, 0x0004},
 {0x013447, 0x0010},
-{0x013456, 0x0080},
+{0x013456, 0x0001},
 {0x014400, 0x0004},
-{0x014647, 0x0080},
+{0x014647, 0x0001},
 {0x016800, 0x0004},
-{0x016A39, 0x0080},
+{0x016A39, 0x0001},
 {0x016A40, 0x0004},
-{0x016A5F, 0x0080},
+{0x016A5F, 0x0001},
 {0x016A60, 0x0002},
-{0x016A6A, 0x0080},
+{0x016A6A, 0x0001},
 {0x016A6E, 0x0020},
 {0x016A70, 0x0004},
-{0x016ABF, 0x0080},
+{0x016ABF, 0x0001},
 {0x016AC0, 0x0002},
-{0x016ACA, 0x0080},
+{0x016ACA, 0x0001},
 {0x016AD0, 0x0004},
-{0x016AEE, 0x0080},
+{0x016AEE, 0x0001},
 {0x016AF0, 0x0010},
 {0x016AF5, 0x0020},
-{0x016AF6, 0x0080},
+{0x016AF6, 0x0001},
 {0x016B00, 0x0004},
 {0x016B30, 0x0010},
 {0x016B37, 0x0020},
@@ -1849,81 +1861,82 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x016B40, 0x0004},
 {0x016B44, 0x0020},
 {0x016B45, 0x0040},
-{0x016B46, 0x0080},
+{0x016B46, 0x0001},
 {0x016B50, 0x0002},
-{0x016B5A, 0x0080},
+{0x016B5A, 0x0001},
 {0x016B5B, 0x0002},
-{0x016B62, 0x0080},
+{0x016B62, 0x0001},
 {0x016B63, 0x0004},
-{0x016B78, 0x0080},
+{0x016B78, 0x0001},
 {0x016B7D, 0x0004},
-{0x016B90, 0x0080},
+{0x016B90, 0x0001},
 {0x016E40, 0x0004},
 {0x016E80, 0x0002},
 {0x016E97, 0x0020},
-{0x016E9B, 0x0080},
+{0x016E9B, 0x0001},
 {0x016F00, 0x0004},
-{0x016F4B, 0x0080},
+{0x016F4B, 0x0001},
 {0x016F4F, 0x0010},
 {0x016F50, 0x0004},
 {0x016F51, 0x0010},
-{0x016F88, 0x0080},
+{0x016F88, 0x0001},
 {0x016F8F, 0x0010},
 {0x016F93, 0x0004},
-{0x016FA0, 0x0080},
+{0x016FA0, 0x0001},
 {0x016FE0, 0x0004},
 {0x016FE2, 0x0020},
 {0x016FE3, 0x0004},
 {0x016FE4, 0x0010},
-{0x016FE5, 0x0080},
+{0x016FE5, 0x0001},
 {0x016FF0, 0x0010},
-{0x016FF2, 0x0080},
+{0x016FF2, 0x0001},
 {0x017000, 0x0004},
-{0x0187F8, 0x0080},
+{0x0187F8, 0x0001},
 {0x018800, 0x0004},
-{0x018CD6, 0x0080},
+{0x018CD6, 0x0001},
 {0x018D00, 0x0004},
-{0x018D09, 0x0080},
+{0x018D09, 0x0001},
 {0x01AFF0, 0x0004},
-{0x01AFF4, 0x0080},
+{0x01AFF4, 0x0001},
 {0x01AFF5, 0x0004},
-{0x01AFFC, 0x0080},
+{0x01AFFC, 0x0001},
 {0x01AFFD, 0x0004},
-{0x01AFFF, 0x0080},
+{0x01AFFF, 0x0001},
 {0x01B000, 0x0004},
-{0x01B123, 0x0080},
+{0x01B123, 0x0001},
 {0x01B132, 0x0004},
-{0x01B133, 0x0080},
+{0x01B133, 0x0001},
 {0x01B150, 0x0004},
-{0x01B153, 0x0080},
+{0x01B153, 0x0001},
 {0x01B155, 0x0004},
-{0x01B156, 0x0080},
+{0x01B156, 0x0001},
 {0x01B164, 0x0004},
-{0x01B168, 0x0080},
+{0x01B168, 0x0001},
 {0x01B170, 0x0004},
-{0x01B2FC, 0x0080},
+{0x01B2FC, 0x0001},
 {0x01BC00, 0x0004},
-{0x01BC6B, 0x0080},
+{0x01BC6B, 0x0001},
 {0x01BC70, 0x0004},
-{0x01BC7D, 0x0080},
+{0x01BC7D, 0x0001},
 {0x01BC80, 0x0004},
-{0x01BC89, 0x0080},
+{0x01BC89, 0x0001},
 {0x01BC90, 0x0004},
-{0x01BC9A, 0x0080},
+{0x01BC9A, 0x0001},
 {0x01BC9C, 0x0040},
 {0x01BC9D, 0x0010},
 {0x01BC9F, 0x0020},
 {0x01BCA0, 0x0080},
+{0x01BCA4, 0x0001},
 {0x01CF00, 0x0010},
-{0x01CF2E, 0x0080},
+{0x01CF2E, 0x0001},
 {0x01CF30, 0x0010},
-{0x01CF47, 0x0080},
+{0x01CF47, 0x0001},
 {0x01CF50, 0x0040},
-{0x01CFC4, 0x0080},
+{0x01CFC4, 0x0001},
 {0x01D000, 0x0040},
-{0x01D0F6, 0x0080},
+{0x01D0F6, 0x0001},
 {0x01D100, 0x0040},
-{0x01D127, 0x0080},
+{0x01D127, 0x0001},
 {0x01D129, 0x0040},
 {0x01D165, 0x0010},
 {0x01D16A, 0x0040},
@@ -1935,57 +1948,57 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x01D18C, 0x0040},
 {0x01D1AA, 0x0010},
 {0x01D1AE, 0x0040},
-{0x01D1EB, 0x0080},
+{0x01D1EB, 0x0001},
 {0x01D200, 0x0040},
 {0x01D242, 0x0010},
 {0x01D245, 0x0040},
-{0x01D246, 0x0080},
+{0x01D246, 0x0001},
 {0x01D2C0, 0x0002},
-{0x01D2D4, 0x0080},
+{0x01D2D4, 0x0001},
 {0x01D2E0, 0x0002},
-{0x01D2F4, 0x0080},
+{0x01D2F4, 0x0001},
 {0x01D300, 0x0040},
-{0x01D357, 0x0080},
+{0x01D357, 0x0001},
 {0x01D360, 0x0002},
-{0x01D379, 0x0080},
+{0x01D379, 0x0001},
 {0x01D400, 0x0004},
-{0x01D455, 0x0080},
+{0x01D455, 0x0001},
 {0x01D456, 0x0004},
-{0x01D49D, 0x0080},
+{0x01D49D, 0x0001},
 {0x01D49E, 0x0004},
-{0x01D4A0, 0x0080},
+{0x01D4A0, 0x0001},
 {0x01D4A2, 0x0004},
-{0x01D4A3, 0x0080},
+{0x01D4A3, 0x0001},
 {0x01D4A5, 0x0004},
-{0x01D4A7, 0x0080},
+{0x01D4A7, 0x0001},
 {0x01D4A9, 0x0004},
-{0x01D4AD, 0x0080},
+{0x01D4AD, 0x0001},
 {0x01D4AE, 0x0004},
-{0x01D4BA, 0x0080},
+{0x01D4BA, 0x0001},
 {0x01D4BB, 0x0004},
-{0x01D4BC, 0x0080},
+{0x01D4BC, 0x0001},
 {0x01D4BD, 0x0004},
-{0x01D4C4, 0x0080},
+{0x01D4C4, 0x0001},
 {0x01D4C5, 0x0004},
-{0x01D506, 0x0080},
+{0x01D506, 0x0001},
 {0x01D507, 0x0004},
-{0x01D50B, 0x0080},
+{0x01D50B, 0x0001},
 {0x01D50D, 0x0004},
-{0x01D515, 0x0080},
+{0x01D515, 0x0001},
 {0x01D516, 0x0004},
-{0x01D51D, 0x0080},
+{0x01D51D, 0x0001},
 {0x01D51E, 0x0004},
-{0x01D53A, 0x0080},
+{0x01D53A, 0x0001},
 {0x01D53B, 0x0004},
-{0x01D53F, 0x0080},
+{0x01D53F, 0x0001},
 {0x01D540, 0x0004},
-{0x01D545, 0x0080},
+{0x01D545, 0x0001},
 {0x01D546, 0x0004},
-{0x01D547, 0x0080},
+{0x01D547, 0x0001},
 {0x01D54A, 0x0004},
-{0x01D551, 0x0080},
+{0x01D551, 0x0001},
 {0x01D552, 0x0004},
-{0x01D6A6, 0x0080},
+{0x01D6A6, 0x0001},
 {0x01D6A8, 0x0004},
 {0x01D6C1, 0x0040},
 {0x01D6C2, 0x0004},
@@ -2007,7 +2020,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x01D7AA, 0x0004},
 {0x01D7C3, 0x0040},
 {0x01D7C4, 0x0004},
-{0x01D7CC, 0x0080},
+{0x01D7CC, 0x0001},
 {0x01D7CE, 0x0002},
 {0x01D800, 0x0040},
 {0x01DA00, 0x0010},
@@ -2019,251 +2032,283 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x01DA84, 0x0010},
 {0x01DA85, 0x0040},
 {0x01DA87, 0x0020},
-{0x01DA8C, 0x0080},
+{0x01DA8C, 0x0001},
 {0x01DA9B, 0x0010},
-{0x01DAA0, 0x0080},
+{0x01DAA0, 0x0001},
 {0x01DAA1, 0x0010},
-{0x01DAB0, 0x0080},
+{0x01DAB0, 0x0001},
 {0x01DF00, 0x0004},
-{0x01DF1F, 0x0080},
+{0x01DF1F, 0x0001},
 {0x01DF25, 0x0004},
-{0x01DF2B, 0x0080},
+{0x01DF2B, 0x0001},
 {0x01E000, 0x0010},
-{0x01E007, 0x0080},
+{0x01E007, 0x0001},
 {0x01E008, 0x0010},
-{0x01E019, 0x0080},
+{0x01E019, 0x0001},
 {0x01E01B, 0x0010},
-{0x01E022, 0x0080},
+{0x01E022, 0x0001},
 {0x01E023, 0x0010},
-{0x01E025, 0x0080},
+{0x01E025, 0x0001},
 {0x01E026, 0x0010},
-{0x01E02B, 0x0080},
+{0x01E02B, 0x0001},
 {0x01E030, 0x0004},
-{0x01E06E, 0x0080},
+{0x01E06E, 0x0001},
 {0x01E08F, 0x0010},
-{0x01E090, 0x0080},
+{0x01E090, 0x0001},
 {0x01E100, 0x0004},
-{0x01E12D, 0x0080},
+{0x01E12D, 0x0001},
 {0x01E130, 0x0010},
 {0x01E137, 0x0004},
-{0x01E13E, 0x0080},
+{0x01E13E, 0x0001},
 {0x01E140, 0x0002},
-{0x01E14A, 0x0080},
+{0x01E14A, 0x0001},
 {0x01E14E, 0x0004},
 {0x01E14F, 0x0040},
-{0x01E150, 0x0080},
+{0x01E150, 0x0001},
 {0x01E290, 0x0004},
 {0x01E2AE, 0x0010},
-{0x01E2AF, 0x0080},
+{0x01E2AF, 0x0001},
 {0x01E2C0, 0x0004},
 {0x01E2EC, 0x0010},
 {0x01E2F0, 0x0002},
-{0x01E2FA, 0x0080},
+{0x01E2FA, 0x0001},
 {0x01E2FF, 0x0040},
-{0x01E300, 0x0080},
+{0x01E300, 0x0001},
 {0x01E4D0, 0x0004},
 {0x01E4EC, 0x0010},
 {0x01E4F0, 0x0002},
-{0x01E4FA, 0x0080},
+{0x01E4FA, 0x0001},
 {0x01E7E0, 0x0004},
-{0x01E7E7, 0x0080},
+{0x01E7E7, 0x0001},
 {0x01E7E8, 0x0004},
-{0x01E7EC, 0x0080},
+{0x01E7EC, 0x0001},
 {0x01E7ED, 0x0004},
-{0x01E7EF, 0x0080},
+{0x01E7EF, 0x0001},
 {0x01E7F0, 0x0004},
-{0x01E7FF, 0x0080},
+{0x01E7FF, 0x0001},
 {0x01E800, 0x0004},
-{0x01E8C5, 0x0080},
+{0x01E8C5, 0x0001},
 {0x01E8C7, 0x0002},
 {0x01E8D0, 0x0010},
-{0x01E8D7, 0x0080},
+{0x01E8D7, 0x0001},
 {0x01E900, 0x0004},
 {0x01E944, 0x0010},
 {0x01E94B, 0x0004},
-{0x01E94C, 0x0080},
+{0x01E94C, 0x0001},
 {0x01E950, 0x0002},
-{0x01E95A, 0x0080},
+{0x01E95A, 0x0001},
 {0x01E95E, 0x0020},
-{0x01E960, 0x0080},
+{0x01E960, 0x0001},
 {0x01EC71, 0x0002},
 {0x01ECAC, 0x0040},
 {0x01ECAD, 0x0002},
 {0x01ECB0, 0x0040},
 {0x01ECB1, 0x0002},
-{0x01ECB5, 0x0080},
+{0x01ECB5, 0x0001},
 {0x01ED01, 0x0002},
 {0x01ED2E, 0x0040},
 {0x01ED2F, 0x0002},
-{0x01ED3E, 0x0080},
+{0x01ED3E, 0x0001},
 {0x01EE00, 0x0004},
-{0x01EE04, 0x0080},
+{0x01EE04, 0x0001},
 {0x01EE05, 0x0004},
-{0x01EE20, 0x0080},
+{0x01EE20, 0x0001},
 {0x01EE21, 0x0004},
-{0x01EE23, 0x0080},
+{0x01EE23, 0x0001},
 {0x01EE24, 0x0004},
-{0x01EE25, 0x0080},
+{0x01EE25, 0x0001},
 {0x01EE27, 0x0004},
-{0x01EE28, 0x0080},
+{0x01EE28, 0x0001},
 {0x01EE29, 0x0004},
-{0x01EE33, 0x0080},
+{0x01EE33, 0x0001},
 {0x01EE34, 0x0004},
-{0x01EE38, 0x0080},
+{0x01EE38, 0x0001},
 {0x01EE39, 0x0004},
-{0x01EE3A, 0x0080},
+{0x01EE3A, 0x0001},
 {0x01EE3B, 0x0004},
-{0x01EE3C, 0x0080},
+{0x01EE3C, 0x0001},
 {0x01EE42, 0x0004},
-{0x01EE43, 0x0080},
+{0x01EE43, 0x0001},
 {0x01EE47, 0x0004},
-{0x01EE48, 0x0080},
+{0x01EE48, 0x0001},
 {0x01EE49, 0x0004},
-{0x01EE4A, 0x0080},
+{0x01EE4A, 0x0001},
 {0x01EE4B, 0x0004},
-{0x01EE4C, 0x0080},
+{0x01EE4C, 0x0001},
 {0x01EE4D, 0x0004},
-{0x01EE50, 0x0080},
+{0x01EE50, 0x0001},
 {0x01EE51, 0x0004},
-{0x01EE53, 0x0080},
+{0x01EE53, 0x0001},
 {0x01EE54, 0x0004},
-{0x01EE55, 0x0080},
+{0x01EE55, 0x0001},
 {0x01EE57, 0x0004},
-{0x01EE58, 0x0080},
+{0x01EE58, 0x0001},
 {0x01EE59, 0x0004},
-{0x01EE5A, 0x0080},
+{0x01EE5A, 0x0001},
 {0x01EE5B, 0x0004},
-{0x01EE5C, 0x0080},
+{0x01EE5C, 0x0001},
 {0x01EE5D, 0x0004},
-{0x01EE5E, 0x0080},
+{0x01EE5E, 0x0001},
 {0x01EE5F, 0x0004},
-{0x01EE60, 0x0080},
+{0x01EE60, 0x0001},
 {0x01EE61, 0x0004},
-{0x01EE63, 0x0080},
+{0x01EE63, 0x0001},
 {0x01EE64, 0x0004},
-{0x01EE65, 0x0080},
+{0x01EE65, 0x0001},
 {0x01EE67, 0x0004},
-{0x01EE6B, 0x0080},
+{0x01EE6B, 0x0001},
 {0x01EE6C, 0x0004},
-{0x01EE73, 0x0080},
+{0x01EE73, 0x0001},
 {0x01EE74, 0x0004},
-{0x01EE78, 0x0080},
+{0x01EE78, 0x0001},
 {0x01EE79, 0x0004},
-{0x01EE7D, 0x0080},
+{0x01EE7D, 0x0001},
 {0x01EE7E, 0x0004},
-{0x01EE7F, 0x0080},
+{0x01EE7F, 0x0001},
 {0x01EE80, 0x0004},
-{0x01EE8A, 0x0080},
+{0x01EE8A, 0x0001},
 {0x01EE8B, 0x0004},
-{0x01EE9C, 0x0080},
+{0x01EE9C, 0x0001},
 {0x01EEA1, 0x0004},
-{0x01EEA4, 0x0080},
+{0x01EEA4, 0x0001},
 {0x01EEA5, 0x0004},
-{0x01EEAA, 0x0080},
+{0x01EEAA, 0x0001},
 {0x01EEAB, 0x0004},
-{0x01EEBC, 0x0080},
+{0x01EEBC, 0x0001},
 {0x01EEF0, 0x0040},
-{0x01EEF2, 0x0080},
+{0x01EEF2, 0x0001},
 {0x01F000, 0x0040},
-{0x01F02C, 0x0080},
+{0x01F02C, 0x0001},
 {0x01F030, 0x0040},
-{0x01F094, 0x0080},
+{0x01F094, 0x0001},
 {0x01F0A0, 0x0040},
-{0x01F0AF, 0x0080},
+{0x01F0AF, 0x0001},
 {0x01F0B1, 0x0040},
-{0x01F0C0, 0x0080},
+{0x01F0C0, 0x0001},
 {0x01F0C1, 0x0040},
-{0x01F0D0, 0x0080},
+{0x01F0D0, 0x0001},
 {0x01F0D1, 0x0040},
-{0x01F0F6, 0x0080},
+{0x01F0F6, 0x0001},
 {0x01F100, 0x0002},
 {0x01F10D, 0x0040},
-{0x01F1AE, 0x0080},
+{0x01F1AE, 0x0001},
 {0x01F1E6, 0x0040},
-{0x01F203, 0x0080},
+{0x01F203, 0x0001},
 {0x01F210, 0x0040},
-{0x01F23C, 0x0080},
+{0x01F23C, 0x0001},
 {0x01F240, 0x0040},
-{0x01F249, 0x0080},
+{0x01F249, 0x0001},
 {0x01F250, 0x0040},
-{0x01F252, 0x0080},
+{0x01F252, 0x0001},
 {0x01F260, 0x0040},
-{0x01F266, 0x0080},
+{0x01F266, 0x0001},
 {0x01F300, 0x0040},
-{0x01F6D8, 0x0080},
+{0x01F6D8, 0x0001},
 {0x01F6DC, 0x0040},
-{0x01F6ED, 0x0080},
+{0x01F6ED, 0x0001},
 {0x01F6F0, 0x0040},
-{0x01F6FD, 0x0080},
+{0x01F6FD, 0x0001},
 {0x01F700, 0x0040},
-{0x01F777, 0x0080},
+{0x01F777, 0x0001},
 {0x01F77B, 0x0040},
-{0x01F7DA, 0x0080},
+{0x01F7DA, 0x0001},
 {0x01F7E0, 0x0040},
-{0x01F7EC, 0x0080},
+{0x01F7EC, 0x0001},
 {0x01F7F0, 0x0040},
-{0x01F7F1, 0x0080},
+{0x01F7F1, 0x0001},
 {0x01F800, 0x0040},
-{0x01F80C, 0x0080},
+{0x01F80C, 0x0001},
 {0x01F810, 0x0040},
-{0x01F848, 0x0080},
+{0x01F848, 0x0001},
 {0x01F850, 0x0040},
-{0x01F85A, 0x0080},
+{0x01F85A, 0x0001},
 {0x01F860, 0x0040},
-{0x01F888, 0x0080},
+{0x01F888, 0x0001},
 {0x01F890, 0x0040},
-{0x01F8AE, 0x0080},
+{0x01F8AE, 0x0001},
 {0x01F8B0, 0x0040},
-{0x01F8B2, 0x0080},
+{0x01F8B2, 0x0001},
 {0x01F900, 0x0040},
-{0x01FA54, 0x0080},
+{0x01FA54, 0x0001},
 {0x01FA60, 0x0040},
-{0x01FA6E, 0x0080},
+{0x01FA6E, 0x0001},
 {0x01FA70, 0x0040},
-{0x01FA7D, 0x0080},
+{0x01FA7D, 0x0001},
 {0x01FA80, 0x0040},
-{0x01FA89, 0x0080},
+{0x01FA89, 0x0001},
 {0x01FA90, 0x0040},
-{0x01FABE, 0x0080},
+{0x01FABE, 0x0001},
 {0x01FABF, 0x0040},
-{0x01FAC6, 0x0080},
+{0x01FAC6, 0x0001},
 {0x01FACE, 0x0040},
-{0x01FADC, 0x0080},
+{0x01FADC, 0x0001},
 {0x01FAE0, 0x0040},
-{0x01FAE9, 0x0080},
+{0x01FAE9, 0x0001},
 {0x01FAF0, 0x0040},
-{0x01FAF9, 0x0080},
+{0x01FAF9, 0x0001},
 {0x01FB00, 0x0040},
-{0x01FB93, 0x0080},
+{0x01FB93, 0x0001},
 {0x01FB94, 0x0040},
-{0x01FBCB, 0x0080},
+{0x01FBCB, 0x0001},
 {0x01FBF0, 0x0002},
-{0x01FBFA, 0x0080},
+{0x01FBFA, 0x0001},
 {0x020000, 0x0004},
-{0x02A6E0, 0x0080},
+{0x02A6E0, 0x0001},
 {0x02A700, 0x0004},
-{0x02B73A, 0x0080},
+{0x02B73A, 0x0001},
 {0x02B740, 0x0004},
-{0x02B81E, 0x0080},
+{0x02B81E, 0x0001},
 {0x02B820, 0x0004},
-{0x02CEA2, 0x0080},
+{0x02CEA2, 0x0001},
 {0x02CEB0, 0x0004},
-{0x02EBE1, 0x0080},
+{0x02EBE1, 0x0001},
 {0x02EBF0, 0x0004},
-{0x02EE5E, 0x0080},
+{0x02EE5E, 0x0001},
 {0x02F800, 0x0004},
-{0x02FA1E, 0x0080},
+{0x02FA1E, 0x0001},
 {0x030000, 0x0004},
-{0x03134B, 0x0080},
+{0x03134B, 0x0001},
 {0x031350, 0x0004},
-{0x0323B0, 0x0080},
+{0x0323B0, 0x0001},
+{0x0E0001, 0x0080},
+{0x0E0002, 0x0001},
+{0x0E0020, 0x0080},
+{0x0E0080, 0x0001},
 {0x0E0100, 0x0010},
-{0x0E01F0, 0x0080},
+{0x0E01F0, 0x0001},
+{0x0F0000, 0x0080},
+{0x0FFFFE, 0x0001},
+{0x100000, 0x0080},
+{0x10FFFE, 0x0001},
 {0x110000, 0x0000},
 };
 
 const std::unordered_set unicode_set_whitespace = {
-0x000009, 0x00000A, 0x00000B, 0x00000C, 0x00000D, 0x000020, 0x000085, 0x0000A0, 0x001680, 0x002000, 0x002001, 0x002002, 0x002003, 0x002004, 0x002005, 0x002006, 0x002007, 0x002008, 0x002009, 0x00200A, 0x002028, 0x002029, 0x00202F, 0x00205F, 0x003000
+0x000009,
+0x00000A,
+0x00000B,
+0x00000C,
+0x00000D,
+0x000020,
+0x000085,
+0x0000A0,
+0x001680,
+0x002000,
+0x002001,
+0x002002,
+0x002003,
+0x002004,
+0x002005,
+0x002006,
+0x002007,
+0x002008,
+0x002009,
+0x00200A,
+0x002028,
+0x002029,
+0x00202F,
+0x00205F,
+0x003000,
 };
 
 const std::unordered_map unicode_map_lowercase = {
@@ -3222,6 +3267,7 @@ const std::unordered_map unicode_map_lowercase = {
 {0x002C2C, 0x002C5C},
 {0x002C2D, 0x002C5D},
 {0x002C2E, 0x002C5E},
+{0x002C2F, 0x002C5F},
 {0x002C60, 0x002C61},
 {0x002C62, 0x00026B},
 {0x002C63, 0x001D7D},
@@ -3402,12 +3448,16 @@ const std::unordered_map unicode_map_lowercase = {
 {0x00A7BA, 0x00A7BB},
 {0x00A7BC, 0x00A7BD},
 {0x00A7BE, 0x00A7BF},
+{0x00A7C0, 0x00A7C1},
 {0x00A7C2, 0x00A7C3},
 {0x00A7C4, 0x00A794},
 {0x00A7C5, 0x000282},
 {0x00A7C6, 0x001D8E},
 {0x00A7C7, 0x00A7C8},
 {0x00A7C9, 0x00A7CA},
+{0x00A7D0, 0x00A7D1},
+{0x00A7D6, 0x00A7D7},
+{0x00A7D8, 0x00A7D9},
 {0x00A7F5, 0x00A7F6},
 {0x00FF21, 0x00FF41},
 {0x00FF22, 0x00FF42},
@@ -3511,6 +3561,41 @@ const std::unordered_map unicode_map_lowercase = {
 {0x0104D1, 0x0104F9},
 {0x0104D2, 0x0104FA},
 {0x0104D3, 0x0104FB},
+{0x010570, 0x010597},
+{0x010571, 0x010598},
+{0x010572, 0x010599},
+{0x010573, 0x01059A},
+{0x010574, 0x01059B},
+{0x010575, 0x01059C},
+{0x010576, 0x01059D},
+{0x010577, 0x01059E},
+{0x010578, 0x01059F},
+{0x010579, 0x0105A0},
+{0x01057A, 0x0105A1},
+{0x01057C, 0x0105A3},
+{0x01057D, 0x0105A4},
+{0x01057E, 0x0105A5},
+{0x01057F, 0x0105A6},
+{0x010580, 0x0105A7},
+{0x010581, 0x0105A8},
+{0x010582, 0x0105A9},
+{0x010583, 0x0105AA},
+{0x010584, 0x0105AB},
+{0x010585, 0x0105AC},
+{0x010586, 0x0105AD},
+{0x010587, 0x0105AE},
+{0x010588, 0x0105AF},
+{0x010589, 0x0105B0},
+{0x01058A, 0x0105B1},
+{0x01058C, 0x0105B3},
+{0x01058D, 0x0105B4},
+{0x01058E, 0x0105B5},
+{0x01058F, 0x0105B6},
+{0x010590, 0x0105B7},
+{0x010591, 0x0105B8},
+{0x010592, 0x0105B9},
+{0x010594, 0x0105BB},
+{0x010595, 0x0105BC},
 {0x010C80, 0x010CC0},
 {0x010C81, 0x010CC1},
 {0x010C82, 0x010CC2},
@@ -3690,7 +3775,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x000079, 0x000059},
 {0x00007A, 0x00005A},
 {0x0000B5, 0x00039C},
-{0x0000DF, 0x000053},
 {0x0000E0, 0x0000C0},
 {0x0000E1, 0x0000C1},
 {0x0000E2, 0x0000C2},
@@ -3758,7 +3842,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x000144, 0x000143},
 {0x000146, 0x000145},
 {0x000148, 0x000147},
-{0x000149, 0x0002BC},
 {0x00014B, 0x00014A},
 {0x00014D, 0x00014C},
 {0x00014F, 0x00014E},
@@ -3831,7 +3914,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x0001EB, 0x0001EA},
 {0x0001ED, 0x0001EC},
 {0x0001EF, 0x0001EE},
-{0x0001F0, 0x00004A},
 {0x0001F2, 0x0001F1},
 {0x0001F3, 0x0001F1},
 {0x0001F5, 0x0001F4},
@@ -3917,12 +3999,10 @@ const std::unordered_map unicode_map_uppercase = {
 {0x00037B, 0x0003FD},
 {0x00037C, 0x0003FE},
 {0x00037D, 0x0003FF},
-{0x000390, 0x000399},
 {0x0003AC, 0x000386},
 {0x0003AD, 0x000388},
 {0x0003AE, 0x000389},
 {0x0003AF, 0x00038A},
-{0x0003B0, 0x0003A5},
 {0x0003B1, 0x000391},
 {0x0003B2, 0x000392},
 {0x0003B3, 0x000393},
@@ -4163,7 +4243,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x000584, 0x000554},
 {0x000585, 0x000555},
 {0x000586, 0x000556},
-{0x000587, 0x000535},
 {0x0010D0, 0x001C90},
 {0x0010D1, 0x001C91},
 {0x0010D2, 0x001C92},
@@ -4303,11 +4382,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x001E91, 0x001E90},
 {0x001E93, 0x001E92},
 {0x001E95, 0x001E94},
-{0x001E96, 0x000048},
-{0x001E97, 0x000054},
-{0x001E98, 0x000057},
-{0x001E99, 0x000059},
-{0x001E9A, 0x000041},
 {0x001E9B, 0x001E60},
 {0x001EA1, 0x001EA0},
 {0x001EA3, 0x001EA2},
@@ -4393,13 +4467,9 @@ const std::unordered_map unicode_map_uppercase = {
 {0x001F43, 0x001F4B},
 {0x001F44, 0x001F4C},
 {0x001F45, 0x001F4D},
-{0x001F50, 0x0003A5},
 {0x001F51, 0x001F59},
-{0x001F52, 0x0003A5},
 {0x001F53, 0x001F5B},
-{0x001F54, 0x0003A5},
 {0x001F55, 0x001F5D},
-{0x001F56, 0x0003A5},
 {0x001F57, 0x001F5F},
 {0x001F60, 0x001F68},
 {0x001F61, 0x001F69},
@@ -4423,89 +4493,41 @@ const std::unordered_map unicode_map_uppercase = {
 {0x001F7B, 0x001FEB},
 {0x001F7C, 0x001FFA},
 {0x001F7D, 0x001FFB},
-{0x001F80, 0x001F08},
-{0x001F81, 0x001F09},
-{0x001F82, 0x001F0A},
-{0x001F83, 0x001F0B},
-{0x001F84, 0x001F0C},
-{0x001F85, 0x001F0D},
-{0x001F86, 0x001F0E},
-{0x001F87, 0x001F0F},
-{0x001F88, 0x001F08},
-{0x001F89, 0x001F09},
-{0x001F8A, 0x001F0A},
-{0x001F8B, 0x001F0B},
-{0x001F8C, 0x001F0C},
-{0x001F8D, 0x001F0D},
-{0x001F8E, 0x001F0E},
-{0x001F8F, 0x001F0F},
-{0x001F90, 0x001F28},
-{0x001F91, 0x001F29},
-{0x001F92, 0x001F2A},
-{0x001F93, 0x001F2B},
-{0x001F94, 0x001F2C},
-{0x001F95, 0x001F2D},
-{0x001F96, 0x001F2E},
-{0x001F97, 0x001F2F},
-{0x001F98, 0x001F28},
-{0x001F99, 0x001F29},
-{0x001F9A, 0x001F2A},
-{0x001F9B, 0x001F2B},
-{0x001F9C, 0x001F2C},
-{0x001F9D, 0x001F2D},
-{0x001F9E, 0x001F2E},
-{0x001F9F, 0x001F2F},
-{0x001FA0, 0x001F68},
-{0x001FA1, 0x001F69},
-{0x001FA2, 0x001F6A},
-{0x001FA3, 0x001F6B},
-{0x001FA4, 0x001F6C},
-{0x001FA5, 0x001F6D},
-{0x001FA6, 0x001F6E},
-{0x001FA7, 0x001F6F},
-{0x001FA8, 0x001F68},
-{0x001FA9, 0x001F69},
-{0x001FAA, 0x001F6A},
-{0x001FAB, 0x001F6B},
-{0x001FAC, 0x001F6C},
-{0x001FAD, 0x001F6D},
-{0x001FAE, 0x001F6E},
-{0x001FAF, 0x001F6F},
+{0x001F80, 0x001F88},
+{0x001F81, 0x001F89},
+{0x001F82, 0x001F8A},
+{0x001F83, 0x001F8B},
+{0x001F84, 0x001F8C},
+{0x001F85, 0x001F8D},
+{0x001F86, 0x001F8E},
+{0x001F87, 0x001F8F},
+{0x001F90, 0x001F98},
+{0x001F91, 0x001F99},
+{0x001F92, 0x001F9A},
+{0x001F93, 0x001F9B},
+{0x001F94, 0x001F9C},
+{0x001F95, 0x001F9D},
+{0x001F96, 0x001F9E},
+{0x001F97, 0x001F9F},
+{0x001FA0, 0x001FA8},
+{0x001FA1, 0x001FA9},
+{0x001FA2, 0x001FAA},
+{0x001FA3, 0x001FAB},
+{0x001FA4, 0x001FAC},
+{0x001FA5, 0x001FAD},
+{0x001FA6, 0x001FAE},
+{0x001FA7, 0x001FAF},
 {0x001FB0, 0x001FB8},
 {0x001FB1, 0x001FB9},
-{0x001FB2, 0x001FBA},
-{0x001FB3, 0x000391},
-{0x001FB4, 0x000386},
-{0x001FB6, 0x000391},
-{0x001FB7, 0x000391},
-{0x001FBC, 0x000391},
+{0x001FB3, 0x001FBC},
 {0x001FBE, 0x000399},
-{0x001FC2, 0x001FCA},
-{0x001FC3, 0x000397},
-{0x001FC4, 0x000389},
-{0x001FC6, 0x000397},
-{0x001FC7, 0x000397},
-{0x001FCC, 0x000397},
+{0x001FC3, 0x001FCC},
 {0x001FD0, 0x001FD8},
 {0x001FD1, 0x001FD9},
-{0x001FD2, 0x000399},
-{0x001FD3, 0x000399},
-{0x001FD6, 0x000399},
-{0x001FD7, 0x000399},
 {0x001FE0, 0x001FE8},
 {0x001FE1, 0x001FE9},
-{0x001FE2, 0x0003A5},
-{0x001FE3, 0x0003A5},
-{0x001FE4, 0x0003A1},
 {0x001FE5, 0x001FEC},
-{0x001FE6, 0x0003A5},
-{0x001FE7, 0x0003A5},
-{0x001FF2, 0x001FFA},
-{0x001FF3, 0x0003A9},
-{0x001FF4, 0x00038F},
-{0x001FF6, 0x0003A9},
-{0x001FF7, 0x0003A9},
-{0x001FFC, 0x0003A9},
+{0x001FF3, 0x001FFC},
 {0x00214E, 0x002132},
 {0x002170, 0x002160},
 {0x002171, 0x002161},
@@ -4597,6 +4619,7 @@ const std::unordered_map unicode_map_uppercase = {
 {0x002C5C, 0x002C2C},
 {0x002C5D, 0x002C2D},
 {0x002C5E, 0x002C2E},
+{0x002C5F, 0x002C2F},
 {0x002C61, 0x002C60},
 {0x002C65, 0x00023A},
 {0x002C66, 0x00023E},
@@ -4800,9 +4823,13 @@ const std::unordered_map unicode_map_uppercase = {
 {0x00A7BB, 0x00A7BA},
 {0x00A7BD, 0x00A7BC},
 {0x00A7BF, 0x00A7BE},
+{0x00A7C1, 0x00A7C0},
 {0x00A7C3, 0x00A7C2},
 {0x00A7C8, 0x00A7C7},
 {0x00A7CA, 0x00A7C9},
+{0x00A7D1, 0x00A7D0},
+{0x00A7D7, 0x00A7D6},
+{0x00A7D9, 0x00A7D8},
 {0x00A7F6, 0x00A7F5},
 {0x00AB53, 0x00A7B3},
 {0x00AB70, 0x0013A0},
@@ -4885,18 +4912,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x00ABBD, 0x0013ED},
 {0x00ABBE, 0x0013EE},
 {0x00ABBF, 0x0013EF},
-{0x00FB00, 0x000046},
-{0x00FB01, 0x000046},
-{0x00FB02, 0x000046},
-{0x00FB03, 0x000046},
-{0x00FB04, 0x000046},
-{0x00FB05, 0x000053},
-{0x00FB06, 0x000053},
-{0x00FB13, 0x000544},
-{0x00FB14, 0x000544},
-{0x00FB15, 0x000544},
-{0x00FB16, 0x00054E},
-{0x00FB17, 0x000544},
 {0x00FF41, 0x00FF21},
 {0x00FF42, 0x00FF22},
 {0x00FF43, 0x00FF23},
@@ -4999,6 +5014,41 @@ const std::unordered_map unicode_map_uppercase = {
 {0x0104F9, 0x0104D1},
 {0x0104FA, 0x0104D2},
 {0x0104FB, 0x0104D3},
+{0x010597, 0x010570},
+{0x010598, 0x010571},
+{0x010599, 0x010572},
+{0x01059A, 0x010573},
+{0x01059B, 0x010574},
+{0x01059C, 0x010575},
+{0x01059D, 0x010576},
+{0x01059E, 0x010577},
+{0x01059F, 0x010578},
+{0x0105A0, 0x010579},
+{0x0105A1, 0x01057A},
+{0x0105A3, 0x01057C},
+{0x0105A4, 0x01057D},
+{0x0105A5, 0x01057E},
+{0x0105A6, 0x01057F},
+{0x0105A7, 0x010580},
+{0x0105A8, 0x010581},
+{0x0105A9, 0x010582},
+{0x0105AA, 0x010583},
+{0x0105AB, 0x010584},
+{0x0105AC, 0x010585},
+{0x0105AD, 0x010586},
+{0x0105AE, 0x010587},
+{0x0105AF, 0x010588},
+{0x0105B0, 0x010589},
+{0x0105B1, 0x01058A},
+{0x0105B3, 0x01058C},
+{0x0105B4, 0x01058D},
+{0x0105B5, 0x01058E},
+{0x0105B6, 0x01058F},
+{0x0105B7, 0x010590},
+{0x0105B8, 0x010591},
+{0x0105B9, 0x010592},
+{0x0105BB, 0x010594},
+{0x0105BC, 0x010595},
 {0x010CC0, 0x010C80},
 {0x010CC1, 0x010C81},
 {0x010CC2, 0x010C82},
@@ -6980,4 +7030,3 @@ const std::vector unicode_ranges_nfd = {  // start, last, nfd
 {0x02FA1C, 0x02FA1C, 0x009F3B},
 {0x02FA1D, 0x02FA1D, 0x02A600},
 };
-
diff --git a/cpp/unicode.cpp b/cpp/unicode.cpp
index 056a4c74..e05fb9d1 100644
--- a/cpp/unicode.cpp
+++ b/cpp/unicode.cpp
@@ -1,3 +1,7 @@
+#if defined(_MSC_VER)
+#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+#endif
+
 #include "unicode.h"
 #include "unicode-data.h"
 
@@ -23,7 +27,7 @@ static std::string unicode_cpts_to_utf8(const std::vector & cps) {
     return result;
 }
 
-static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
     assert(offset < utf8.size());
     if (!(utf8[offset + 0] & 0x80)) {
         auto result = utf8[offset + 0];
@@ -226,13 +230,13 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
         assert(offset_end <= cpts.size());
         start = offset_end;
 
-        auto _get_cpt = [&] (const size_t pos) -> char32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -253,18 +257,18 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
         };
 
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
-            const char32_t cpt = _get_cpt(pos);
+            const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
             // regex: 's|'t|'re|'ve|'m|'ll|'d
             if (cpt == '\'' && pos+1 < offset_end) {
-                char32_t cpt_next = _get_cpt(pos+1);
+                uint32_t cpt_next = _get_cpt(pos+1);
                 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                     pos += _add_token(pos+2);
                     continue;
                 }
                 if (pos+2 < offset_end) {
-                    char32_t cpt_next_next = _get_cpt(pos+2);
+                    uint32_t cpt_next_next = _get_cpt(pos+2);
                     if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                         (cpt_next == 'v' && cpt_next_next == 'e') ||
                         (cpt_next == 'l' && cpt_next_next == 'l')) {
@@ -294,9 +298,9 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
                 continue;
             }
             // regex: ?[^\s\p{L}\p{N}]+
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
                 _add_token(pos);
@@ -309,7 +313,7 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -344,13 +348,13 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
         assert(offset_end <= cpts.size());
         start = offset_end;
 
-        auto _get_cpt = [&] (const size_t pos) -> char32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -371,18 +375,18 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
         };
 
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
-            const char32_t cpt = _get_cpt(pos);
+            const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
             // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
             if (cpt == '\'' && pos+1 < offset_end) {
-                char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
                 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                     pos += _add_token(pos+2);
                     continue;
                 }
                 if (pos+2 < offset_end) {
-                    char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
                     if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                         (cpt_next == 'v' && cpt_next_next == 'e') ||
                         (cpt_next == 'l' && cpt_next_next == 'l')) {
@@ -392,8 +396,8 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
                 }
             }
 
-            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
-            if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
+            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
                 if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
                     pos++;
                     while (_get_flags(pos).is_letter) {
@@ -419,12 +423,12 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
 
             // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
             auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
-                char32_t cpt2 = _get_cpt(pos);
+                uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
                     cpt2 = _get_cpt(++pos);
                 }
@@ -435,7 +439,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
             while (_get_flags(pos+num_whitespaces).is_whitespace) {
-                char32_t cpt2 = _get_cpt(pos+num_whitespaces);
+                uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
                 if (cpt2 == '\r' || cpt2 == '\n') {
                     last_end_r_or_n = pos + num_whitespaces + 1;
                 }
@@ -450,7 +454,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -594,6 +598,7 @@ std::vector unicode_cpts_normalize_nfd(const std::vector & c
 
 std::vector unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));
@@ -626,7 +631,7 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
     return map.at(utf8);
 }
 
-char32_t unicode_tolower(char32_t cp) {
+uint32_t unicode_tolower(uint32_t cp) {
     auto it = unicode_map_lowercase.find(cp);
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
@@ -679,10 +684,14 @@ std::vector unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
+            const auto flags = unicode_cpt_flags(cpts[i]);
 
-            if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
+            if (flags.is_whitespace) {
+                //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
+                //text_collapsed[i] = (char) 0x85;  //  as whitespace fallback
+                text_collapsed[i] = (char) 0x0B;    //  as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
             } else {
                 text_collapsed[i] = (char) 0xD0; // fallback
             }
@@ -766,9 +775,16 @@ std::vector unicode_regex_split(const std::string & text, const std
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
             } else {
                 // no unicode category used, we can use std::wregex directly
-                const std::wstring wtext       = unicode_wstring_from_utf8(text);
                 const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
 
+                // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
+                std::wstring wtext(cpts.begin(), cpts.end());
+                for (size_t i = 0; i < wtext.size(); ++i) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                        wtext[i] = 0x0B;
+                    }
+                }
+
                 //printf("text: %s\n", text.c_str());
                 //printf("regex_expr: %s\n", regex_expr.c_str());
                 bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
diff --git a/cpp/unicode.h b/cpp/unicode.h
index 7513be4a..30b07ba7 100644
--- a/cpp/unicode.h
+++ b/cpp/unicode.h
@@ -48,6 +48,7 @@ struct codepoint_flags {
 
 
 std::string unicode_cpt_to_utf8(uint32_t cp);
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
 std::vector unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector unicode_cpts_normalize_nfd(const std::vector & cpts);
@@ -58,6 +59,6 @@ codepoint_flags unicode_cpt_flags(const std::string & utf8);
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
 
-char32_t unicode_tolower(char32_t cp);
+uint32_t unicode_tolower(uint32_t cp);
 
 std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs);
diff --git a/llama.cpp b/llama.cpp
index b864b50c..c3776cac 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit b864b50ce5e2beefc8c2fd31733e4e1a978b7754
+Subproject commit c3776cacabce2ee35f172fb72be7a519752125fa
diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh
index 47cf44bb..8abbcec0 100755
--- a/scripts/bootstrap.sh
+++ b/scripts/bootstrap.sh
@@ -3,27 +3,27 @@
 git submodule init
 git submodule update --recursive
 
-cp ./llama.cpp/ggml.h ./cpp/ggml.h
-cp ./llama.cpp/ggml.c ./cpp/ggml.c
-cp ./llama.cpp/ggml-metal.h ./cpp/ggml-metal.h
-cp ./llama.cpp/ggml-metal.m ./cpp/ggml-metal.m
-cp ./llama.cpp/ggml-alloc.h ./cpp/ggml-alloc.h
-cp ./llama.cpp/ggml-alloc.c ./cpp/ggml-alloc.c
-cp ./llama.cpp/ggml-backend.h ./cpp/ggml-backend.h
-cp ./llama.cpp/ggml-backend.c ./cpp/ggml-backend.c
-cp ./llama.cpp/ggml-backend-impl.h ./cpp/ggml-backend-impl.h
-cp ./llama.cpp/ggml-impl.h ./cpp/ggml-impl.h
-cp ./llama.cpp/ggml-common.h ./cpp/ggml-common.h
-cp ./llama.cpp/llama.h ./cpp/llama.h
-cp ./llama.cpp/llama.cpp ./cpp/llama.cpp
-cp ./llama.cpp/ggml-quants.h ./cpp/ggml-quants.h
-cp ./llama.cpp/ggml-quants.c ./cpp/ggml-quants.c
-cp ./llama.cpp/unicode.h ./cpp/unicode.h
-cp ./llama.cpp/unicode.cpp ./cpp/unicode.cpp
-cp ./llama.cpp/unicode-data.h ./cpp/unicode-data.h
-cp ./llama.cpp/unicode-data.cpp ./cpp/unicode-data.cpp
-cp ./llama.cpp/sgemm.h ./cpp/sgemm.h
-cp ./llama.cpp/sgemm.cpp ./cpp/sgemm.cpp
+cp ./llama.cpp/ggml/include/ggml.h ./cpp/ggml.h
+cp ./llama.cpp/ggml/src/ggml.c ./cpp/ggml.c
+cp ./llama.cpp/ggml/include/ggml-metal.h ./cpp/ggml-metal.h
+cp ./llama.cpp/ggml/src/ggml-metal.m ./cpp/ggml-metal.m
+cp ./llama.cpp/ggml/include/ggml-alloc.h ./cpp/ggml-alloc.h
+cp ./llama.cpp/ggml/src/ggml-alloc.c ./cpp/ggml-alloc.c
+cp ./llama.cpp/ggml/include/ggml-backend.h ./cpp/ggml-backend.h
+cp ./llama.cpp/ggml/src/ggml-backend.c ./cpp/ggml-backend.c
+cp ./llama.cpp/ggml/src/ggml-backend-impl.h ./cpp/ggml-backend-impl.h
+cp ./llama.cpp/ggml/src/ggml-impl.h ./cpp/ggml-impl.h
+cp ./llama.cpp/ggml/src/ggml-common.h ./cpp/ggml-common.h
+cp ./llama.cpp/include/llama.h ./cpp/llama.h
+cp ./llama.cpp/src/llama.cpp ./cpp/llama.cpp
+cp ./llama.cpp/ggml/src/ggml-quants.h ./cpp/ggml-quants.h
+cp ./llama.cpp/ggml/src/ggml-quants.c ./cpp/ggml-quants.c
+cp ./llama.cpp/src/unicode.h ./cpp/unicode.h
+cp ./llama.cpp/src/unicode.cpp ./cpp/unicode.cpp
+cp ./llama.cpp/src/unicode-data.h ./cpp/unicode-data.h
+cp ./llama.cpp/src/unicode-data.cpp ./cpp/unicode-data.cpp
+cp ./llama.cpp/ggml/src/llamafile/sgemm.h ./cpp/sgemm.h
+cp ./llama.cpp/ggml/src/llamafile/sgemm.cpp ./cpp/sgemm.cpp
 cp ./llama.cpp/common/log.h ./cpp/log.h
 cp ./llama.cpp/common/common.h ./cpp/common.h
 cp ./llama.cpp/common/common.cpp ./cpp/common.cpp
@@ -34,6 +34,8 @@ cp ./llama.cpp/common/json-schema-to-grammar.h ./cpp/json-schema-to-grammar.h
 cp ./llama.cpp/common/json-schema-to-grammar.cpp ./cpp/json-schema-to-grammar.cpp
 cp ./llama.cpp/common/sampling.h ./cpp/sampling.h
 cp ./llama.cpp/common/sampling.cpp ./cpp/sampling.cpp
+cp ./llama.cpp/ggml/src/ggml-aarch64.h ./cpp/ggml-aarch64.h
+cp ./llama.cpp/ggml/src/ggml-aarch64.c ./cpp/ggml-aarch64.c
 
 # List of files to process
 files=(
@@ -57,6 +59,8 @@ files=(
   "./cpp/ggml-common.h"
   "./cpp/sgemm.cpp"
   "./cpp/json-schema-to-grammar.h"
+  "./cpp/ggml-aarch64.h"
+  "./cpp/ggml-aarch64.c"
 )
 
 # Loop through each file and run the sed commands
@@ -80,7 +84,7 @@ done
 
 echo "Replacement completed successfully!"
 
-yarn example
+# yarn example
 
 # Apply patch
 patch -p0 -d ./cpp < ./scripts/common.h.patch
@@ -92,11 +96,11 @@ patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch
 
 if [ "$OS" = "Darwin" ]; then
   # Build metallib (~1.4MB)
-  cd llama.cpp
+  cd llama.cpp/ggml/src/
   xcrun --sdk iphoneos metal -c ggml-metal.metal -o ggml-metal.air
   xcrun --sdk iphoneos metallib ggml-metal.air   -o ggml-llama.metallib
   rm ggml-metal.air
-  cp ./ggml-llama.metallib ../cpp/ggml-llama.metallib
+  cp ./ggml-llama.metallib ../../../cpp/ggml-llama.metallib
 
   cd -
 
diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch
index 0d03d38a..e7631a72 100644
--- a/scripts/common.cpp.patch
+++ b/scripts/common.cpp.patch
@@ -1,6 +1,6 @@
 --- common.cpp.orig	2024-05-29 09:16:58
 +++ common.cpp	2024-05-29 09:16:59
-@@ -47,6 +47,12 @@
+@@ -51,6 +51,12 @@
  #include 
  #include 
  #endif
diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch
index a269d0fd..1592f387 100644
--- a/scripts/ggml-metal.m.patch
+++ b/scripts/ggml-metal.m.patch
@@ -1,6 +1,6 @@
 --- ggml-metal.m.orig	2024-05-29 09:16:58
 +++ ggml-metal.m	2024-05-29 09:16:59
-@@ -334,7 +334,7 @@
+@@ -336,7 +336,7 @@
          const bool try_metallib = true;
  #endif
 
diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch
index 816ad178..d24585f0 100644
--- a/scripts/llama.cpp.patch
+++ b/scripts/llama.cpp.patch
@@ -1,6 +1,6 @@
 --- llama.cpp.orig	2024-05-29 09:16:58
 +++ llama.cpp	2024-05-29 09:16:59
-@@ -117,6 +117,17 @@
+@@ -129,6 +129,17 @@
  #define LLAMA_LOG_WARN(...)  llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
  #define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 
@@ -18,7 +18,7 @@
  //
  // helpers
  //
-@@ -1384,16 +1395,16 @@
+@@ -1708,16 +1719,16 @@
 
          if (prefetch > 0) {
              // advise the kernel to preload the mapped memory

From 9722026d9382e3018a4bd751bfb10cd60315f8ec Mon Sep 17 00:00:00 2001
From: jhen 
Date: Mon, 22 Jul 2024 11:59:39 +0800
Subject: [PATCH 2/3] chore(android): format

---
 .../main/java/com/rnllama/LlamaContext.java   | 46 +++++++++----------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index e2044cf1..dad6fb5e 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -237,33 +237,33 @@ public void release() {
   static {
     Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
     if (LlamaContext.isArm64V8a()) {
-        String cpuFeatures = LlamaContext.getCpuFeatures();
-        Log.d(NAME, "CPU features: " + cpuFeatures);
+      String cpuFeatures = LlamaContext.getCpuFeatures();
+      Log.d(NAME, "CPU features: " + cpuFeatures);
 
-        boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
-        boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
-        boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
-        boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
+      boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
+      boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+      boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
+      boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
 
-        if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
-            Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
-            System.loadLibrary("rnllama_v8_4_fp16_dotprod");
-        } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
-            Log.d(NAME, "Loading librnllama_v8_2_fp16_dotprod.so");
-            System.loadLibrary("rnllama_v8_2_fp16_dotprod");
-        } else if (isAtLeastArmV82 && hasFp16) {
-            Log.d(NAME, "Loading librnllama_v8_2_fp16.so");
-            System.loadLibrary("rnllama_v8_2_fp16");
-        } else {
-            Log.d(NAME, "Loading librnllama_v8.so");
-            System.loadLibrary("rnllama_v8");
-        }
+      if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod");
+      } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_2_fp16_dotprod.so");
+        System.loadLibrary("rnllama_v8_2_fp16_dotprod");
+      } else if (isAtLeastArmV82 && hasFp16) {
+        Log.d(NAME, "Loading librnllama_v8_2_fp16.so");
+        System.loadLibrary("rnllama_v8_2_fp16");
+      } else {
+        Log.d(NAME, "Loading librnllama_v8.so");
+        System.loadLibrary("rnllama_v8");
+      }
     } else if (LlamaContext.isX86_64()) {
-        Log.d(NAME, "Loading librnllama_x86_64.so");
-        System.loadLibrary("rnllama_x86_64");
+      Log.d(NAME, "Loading librnllama_x86_64.so");
+      System.loadLibrary("rnllama_x86_64");
     } else {
-        Log.d(NAME, "Loading default librnllama.so");
-        System.loadLibrary("rnllama");
+      Log.d(NAME, "Loading default librnllama.so");
+      System.loadLibrary("rnllama");
     }
   }
 

From 6b731d8996e49988d1d2f4d22ba089590e0f7537 Mon Sep 17 00:00:00 2001
From: jhen 
Date: Mon, 22 Jul 2024 12:00:34 +0800
Subject: [PATCH 3/3] chore: revert unnecessary change

---
 scripts/bootstrap.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh
index 8abbcec0..88a374d3 100755
--- a/scripts/bootstrap.sh
+++ b/scripts/bootstrap.sh
@@ -84,7 +84,7 @@ done
 
 echo "Replacement completed successfully!"
 
-# yarn example
+yarn example
 
 # Apply patch
 patch -p0 -d ./cpp < ./scripts/common.h.patch