diff --git a/common/common.cpp b/common/common.cpp index cdcb352b5a8ae..1591790e6df4c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -200,19 +200,13 @@ void gpt_params_handle_model_default(gpt_params & params) { } params.hf_file = params.model; } else if (params.model.empty()) { - std::string cache_directory = fs_get_cache_directory(); - const bool success = fs_create_directory_with_parents(cache_directory); - if (!success) { - throw std::runtime_error("failed to create cache directory: " + cache_directory); - } - params.model = cache_directory + string_split(params.hf_file, '/').back(); + params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); } } else if (!params.model_url.empty()) { if (params.model.empty()) { auto f = string_split(params.model_url, '#').front(); f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; + params.model = fs_get_cache_file(string_split(f, '/').back()); } } else if (params.model.empty()) { params.model = DEFAULT_MODEL_PATH; @@ -1491,6 +1485,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chat_template = argv[i]; return true; } + if (arg == "--slot-prompt-similarity" || arg == "-sps") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.slot_prompt_similarity = std::stof(argv[i]); + return true; + } if (arg == "-pps") { params.is_pp_shared = true; return true; @@ -1913,6 +1915,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); @@ -2269,6 +2273,16 @@ std::string fs_get_cache_directory() { return ensure_trailing_slash(cache_directory); } +std::string fs_get_cache_file(const std::string & filename) { + GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + return cache_directory + filename; +} + // // Model utils diff --git a/common/common.h b/common/common.h index 35f5311e10fe1..2345d855eed3c 100644 --- a/common/common.h +++ b/common/common.h @@ -203,6 +203,8 @@ struct gpt_params { std::string slot_save_path; + float slot_prompt_similarity = 0.5f; + // batched-bench params bool is_pp_shared = false; @@ -275,6 +277,7 @@ bool fs_validate_filename(const std::string & filename); bool fs_create_directory_with_parents(const std::string & path); std::string fs_get_cache_directory(); +std::string fs_get_cache_file(const std::string & filename); // // Model utils diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a86864f04861b..025405a2c6ce1 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -47,11 +47,12 @@ class Model: _model_classes: dict[str, type[Model]] = {} dir_model: Path - ftype: int + ftype: gguf.LlamaFileType is_big_endian: bool endianess: gguf.GGUFEndian use_temp_file: bool lazy: bool + model_name: str | None part_names: list[str] is_safetensors: bool hparams: dict[str, Any] @@ -64,7 +65,7 @@ class Model: # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -73,10 +74,11 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager - self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors") + self.model_name = model_name + self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: - self.part_names = Model.get_model_part_names(self.dir_model, ".bin") + self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -94,7 +96,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, ftype_lw: str = ftype_up.lower() # allow templating the file name with the output ftype, useful with the "auto" ftype self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) - self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) @classmethod def __init_subclass__(cls): @@ -182,7 +184,7 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", " return new_name def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_block_count(self.block_count) if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: @@ -324,21 +326,21 @@ def write_tensors(self): def write(self): self.write_tensors() - self.gguf_writer.write_header_to_file() + self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_tensors_to_file(progress=True) self.gguf_writer.close() def write_vocab(self): - self.gguf_writer.write_header_to_file() + self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.close() @staticmethod - def get_model_part_names(dir_model: Path, suffix: str) -> list[str]: + def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: part_names: list[str] = [] for filename in os.listdir(dir_model): - if filename.endswith(suffix): + if filename.startswith(prefix) and filename.endswith(suffix): part_names.append(filename) part_names.sort() @@ -665,7 +667,7 @@ class GPTNeoXModel(Model): def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -798,7 +800,7 @@ def set_vocab(self): def set_gguf_parameters(self): block_count = self.hparams["n_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_block_count(block_count) @@ -850,7 +852,7 @@ def set_gguf_parameters(self): raise ValueError("gguf: can not find ctx length parameter.") self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -887,7 +889,7 @@ def set_gguf_parameters(self): else: raise ValueError("gguf: can not find ctx length parameter.") - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -1010,7 +1012,7 @@ def set_gguf_parameters(self): else: raise ValueError("gguf: can not find ctx length parameter.") - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -1206,7 +1208,7 @@ def set_gguf_parameters(self): hparams = self.hparams block_count = hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -1681,7 +1683,7 @@ class GPT2Model(Model): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) @@ -2248,7 +2250,7 @@ def set_gguf_parameters(self): hparams = self.hparams block_count = hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -2348,7 +2350,7 @@ def set_gguf_parameters(self): # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading @@ -2852,7 +2854,7 @@ def main() -> None: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name) logger.info("Set model parameters") model_instance.set_gguf_parameters() diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 62d8282501768..7c15d2aa4acfb 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -6,10 +6,6 @@ #include "ggml-metal.h" #endif -#ifdef GGML_USE_SYCL -#include "ggml-sycl.h" -#endif - #include "ggml-rpc.h" #ifdef _WIN32 # include @@ -83,12 +79,6 @@ static ggml_backend_t create_backend() { if (!backend) { fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); } -#elif GGML_USE_SYCL - fprintf(stderr, "%s: using SYCL backend\n", __func__); - backend = ggml_backend_sycl_init(0); // init device 0 - if (!backend) { - fprintf(stderr, "%s: ggml_backend_sycl_init() failed\n", __func__); - } #endif // if there aren't GPU Backends fallback to CPU backend diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html index d571c27791c72..19c9f643d3027 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public/index-new.html @@ -416,7 +416,7 @@ message = html`<${Probabilities} data=${data} />` } else { const text = isArrayMessage ? - data.map(msg => msg.content).join('').replace(/^\s+/, '') : + data.map(msg => msg.content).join('') : data; message = isCompletionMode ? text : diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 528220607a4f6..6ffaa8d9fe637 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -647,6 +647,9 @@ struct server_context { server_metrics metrics; + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + ~server_context() { if (ctx) { llama_free(ctx); @@ -795,24 +798,88 @@ struct server_context { return prompt_tokens; } - server_slot * get_slot(int id) { - int64_t t_last = ggml_time_us(); - - server_slot * last_used = nullptr; - + server_slot * get_slot_by_id(int id) { for (server_slot & slot : slots) { - if (slot.id == id && slot.available()) { + if (slot.id == id) { return &slot; } + } + + return nullptr; + } + + server_slot * get_available_slot(const std::string & prompt) { + server_slot * ret = nullptr; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) { + int max_lcp_len = 0; + float similarity = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + + // skip the slot if it does not contains prompt + if (!slot.prompt.is_string()) { + continue; + } + + // current slot's prompt + std::string slot_prompt = slot.prompt.get(); + + // length of the current slot's prompt + int slot_prompt_len = slot_prompt.size(); + + // length of the Longest Common Prefix between the current slot's prompt and the input prompt + int lcp_len = common_part(slot_prompt, prompt); + + // fraction of the common substring length compared to the current slot's prompt length + similarity = static_cast(lcp_len) / slot_prompt_len; + + // select the current slot if the criteria match + if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { + max_lcp_len = lcp_len; + ret = &slot; + } + } - // among all available slots, find the one that has been least recently used - if (slot.available() && slot.t_last_used < t_last) { - last_used = &slot; - t_last = slot.t_last_used; + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lcp similarity", { + {"id_slot", ret->id}, + {"max_lcp_len", max_lcp_len}, + {"similarity", similarity}, + }); } } - return last_used; + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lru", { + {"id_slot", ret->id}, + {"t_last", t_last}, + }); + } + } + + return ret; } bool launch_slot_with_task(server_slot & slot, const server_task & task) { @@ -1515,13 +1582,29 @@ struct server_context { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: { - server_slot * slot = get_slot(json_value(task.data, "id_slot", -1)); + int id_slot = json_value(task.data, "id_slot", -1); + std::string prompt = json_value(task.data, "prompt", std::string()); + + server_slot * slot; + + if (id_slot != -1) { + slot = get_slot_by_id(id_slot); + } else { + slot = get_available_slot(prompt); + } + if (slot == nullptr) { // if no slot is available, we defer this task for processing later LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); queue_tasks.defer(task); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } if (task.data.contains("system_prompt")) { std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); @@ -1638,11 +1721,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_SAVE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); @@ -1673,11 +1762,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_RESTORE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } const int64_t t_start = ggml_time_us(); @@ -1715,11 +1810,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_ERASE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } // Erase token cache const size_t n_erased = slot->cache_tokens.size(); @@ -2467,6 +2568,9 @@ int main(int argc, char ** argv) { log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; } + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + // load the model if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index b7bfb41d35edc..63fde9c9faabe 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -253,6 +253,13 @@ static size_t common_part(const std::vector & a, const std::vector< return i; } +static size_t common_part(const std::string & a, const std::string & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dad8a9e2dafe7..af10f21a0a92a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { GGML_UNUSED(main_device); } +static cudaError_t ggml_cuda_Memcpy2DPeerAsync( + void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { + +#if !defined(GGML_USE_HIPBLAS) + // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices + cudaMemcpy3DPeerParms p = {}; + p.dstDevice = dstDevice; + p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height); + p.srcDevice = srcDevice; + p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height); + p.extent = make_cudaExtent(width, height, 1); + return cudaMemcpy3DPeerAsync(&p, stream); +#else + // HIP does not support cudaMemcpy3DPeerAsync or vmm pools + GGML_UNUSED(dstDevice); + GGML_UNUSED(srcDevice); + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); +#endif // !defined(GGML_USE_HIPBLAS) +} + static void ggml_cuda_op_mul_mat( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, - const bool convert_src1_to_q8_1) { + quantize_cuda_t quantize_src1) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat( } struct dev_data { - ggml_cuda_pool_alloc src0_dd_alloc; + int cc; + + ggml_cuda_pool_alloc src0_dd_alloc; ggml_cuda_pool_alloc src1_ddf_alloc; ggml_cuda_pool_alloc src1_ddq_alloc; ggml_cuda_pool_alloc dst_dd_alloc; @@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat( int used_devices = 0; for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + dev[id].cc = ggml_cuda_info().devices[id].cc; + // by default, use all rows dev[id].row_low = 0; dev[id].row_high = ne01; @@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat( dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); } - if (convert_src1_to_q8_1) { - dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq); + } + dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } } @@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat( const int64_t i03 = i0 / ne12; const int64_t i02 = i0 % ne12; - const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq); + } else { + src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs; + } // for split tensors the data begins at i0 == i0_offset_low char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; @@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat( // copy src0, src1 to device if necessary if (src1_is_contiguous) { if (id != ctx.device) { - if (convert_src1_to_q8_1) { + if (quantize_src1) { char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; - CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device, - src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + const size_t pitch = ne11*sizeof(block_q8_1_mmq); + const size_t width = src1_ncols*sizeof(block_q8_1_mmq); + const size_t height = src1_padded_col_size/(4*QK8_1); + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream)); + } else { + CUDA_CHECK(cudaMemcpyPeerAsync( + src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + } } else { float * src1_ddf_i_source = (float *) src1->data; src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; @@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(false); } - if (convert_src1_to_q8_1 && !src1_is_contiguous) { - quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + if (quantize_src1 && !src1_is_contiguous) { + quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat( float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); dhf_dst_i += src1_col_0*ne0 + dev[id].row_low; -#if !defined(GGML_USE_HIPBLAS) - // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices - cudaMemcpy3DPeerParms p = {}; - p.dstDevice = ctx.device; - p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols); - p.srcDevice = id; - p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols); - p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1); - CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream)); -#else - // HIP does not support cudaMemcpy3DPeerAsync or vmm pools - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), - dst_dd_i, row_diff*sizeof(float), - row_diff*sizeof(float), src1_ncols, - cudaMemcpyDeviceToDevice, stream)); -#endif + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync( + dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream)); } else { float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); @@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor // KQ + KQV multi-batch ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); } else { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } } diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 58799e4caf6f8..1d6b9e6982b6e 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -11,6 +11,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t nb01 = src0->nb[1]; const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; @@ -25,7 +26,7 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst}; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; switch (src0->type) { case GGML_TYPE_Q4_0: diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 6744cce6d785f..3ccae8a0c36fa 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -1,15 +1,26 @@ +#pragma once + #include "common.cuh" #include "vecdotq.cuh" #include #include +#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) + typedef void (*load_tiles_mmq_t)( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); typedef void (*vec_dot_mmq_t)( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); + const int * __restrict__ y, float * __restrict__ sum, const int & k0); + +struct block_q8_1_mmq { + half2 ds[4]; + int8_t qs[4*QK8_1]; +}; +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); struct tile_x_sizes { int ql; @@ -132,10 +143,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -145,19 +160,18 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( const int i = i0 + threadIdx.x; const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; int u[2*VDR_Q4_0_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], + y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -203,10 +217,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -221,13 +238,13 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], + y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -293,10 +310,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -306,20 +327,18 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( const int i = i0 + threadIdx.x; const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; + const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; int u[2*VDR_Q5_0_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -383,10 +402,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -396,18 +418,18 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( const int i = i0 + threadIdx.x; const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1; + const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1; int u[2*VDR_Q5_1_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -455,10 +477,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -467,12 +493,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]); + (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); } } } @@ -531,10 +554,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -545,11 +571,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( const int kbx = k0 / QI2_K; const int ky = (k0 % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); #pragma unroll @@ -557,11 +582,11 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; } - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales, + x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -646,7 +671,11 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -658,8 +687,6 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( const int kbx = k0 / QI3_K; const int ky = (k0 % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; @@ -667,19 +694,19 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( #pragma unroll for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); const int shift = 2 * ((ky % 32) / 8); const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); const int vlh = (vh << 2) & 0x04040404; v[l] = __vsubss4(vll, vlh); } - const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales, + x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]); } } } @@ -746,10 +773,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -760,9 +790,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); - const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8, + x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -842,10 +872,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -856,10 +889,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0; - const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8, + x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -932,10 +964,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -944,15 +980,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0; - const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, + x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -964,7 +996,6 @@ struct mmq_type_traits; template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; @@ -972,7 +1003,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; @@ -980,7 +1010,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; @@ -988,7 +1017,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; @@ -996,7 +1024,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; @@ -1004,7 +1031,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; @@ -1012,7 +1038,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; @@ -1020,7 +1045,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; @@ -1028,7 +1052,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; @@ -1036,12 +1059,36 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; }; +static int mmq_need_sum(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return true; + case GGML_TYPE_Q5_0: + return false; + case GGML_TYPE_Q5_1: + return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + return false; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return true; + case GGML_TYPE_Q6_K: + return false; + default: + GGML_ASSERT(false); + break; + } + return false; +} + template #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) @@ -1056,7 +1103,7 @@ template #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, - const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) { + const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device()) { @@ -1068,7 +1115,6 @@ static __global__ void mul_mat_q( constexpr int qr = ggml_cuda_type_traits::qr; constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(mmq_x); - constexpr bool need_sum = mmq_type_traits::need_sum; constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; @@ -1080,62 +1126,38 @@ static __global__ void mul_mat_q( half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql); int * tile_x_qh = (int *) (tile_x_dm + txs.dm); int * tile_x_sc = (int *) (tile_x_qh + txs.qh); - int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] - half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; - - const block_q8_1 * y = (const block_q8_1 *) yc; + int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)] const int blocks_per_row_x = ne00 / qk; - const int blocks_per_col_y = ne10 / QK8_1; const int blocks_per_warp = WARP_SIZE / qi; const int & ne1 = ne11; const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1; + const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); + float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { - load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00); + load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01); #pragma unroll for (int kr = 0; kr < qr; ++kr) { - const int kqs = kr*WARP_SIZE + threadIdx.x; - const int kbxd = kqs / QI8_1; - + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < mmq_x; i0 += nwarps) { - const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses - - const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd]; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; - const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); - } - -#pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); - const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = __low2float(*dsi_src); - } + tile_y[l] = by0[l]; } __syncthreads(); // #pragma unroll // unrolling this loop causes too much register pressure for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0); + vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0); } __syncthreads(); @@ -1165,8 +1187,8 @@ static __global__ void mul_mat_q( struct mmq_args { const char * x; const char * y; float * dst; - int64_t ne00; int64_t ne01; int64_t stride00; - int64_t ne10; int64_t ne11; + int64_t ne00; int64_t ne01; int64_t stride01; + int64_t ne10; int64_t ne11; int64_t stride11; int64_t ne0; }; @@ -1184,7 +1206,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int); const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); - const int shmem = shmem_x + shmem_y; + const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int)); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1198,11 +1220,11 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { if (args.ne01 % mmq_y == 0) { const bool need_check = false; mul_mat_q<<>> - (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); } else { const bool need_check = true; mul_mat_q<<>> - (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); } } diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu index 7578c4b6c7cab..b4678682238d3 100644 --- a/ggml-cuda/quantize.cu +++ b/ggml-cuda/quantize.cu @@ -1,22 +1,23 @@ #include "quantize.cuh" +#include -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) { - const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; - if (ix >= kx_padded) { + if (ix0 >= kx0_padded) { return; } - const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y; + const int64_t ix1 = blockIdx.y; - const int64_t i_padded = (int64_t)iy*kx_padded + ix; + const int64_t i_padded = ix1*kx0_padded + ix0; block_q8_1 * y = (block_q8_1 *) vy; const int64_t ib = i_padded / QK8_1; // block index const int64_t iqs = i_padded % QK8_1; // quant index - const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) { - const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, ky, 1); +template +static __global__ void quantize_mmq_q8_1( + const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (ix0 >= kx0_padded) { + return; + } + + const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + + const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f; + float amax = fabsf(xi); + + amax = warp_reduce_max(amax); + + float sum; + if (need_sum) { + sum = warp_reduce_sum(xi); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs % QK8_1 != 0) { + return; + } + + if (need_sum) { + y[ib].ds[iqs/QK8_1] = make_half2(d, sum); + } else { + ((float *) y[ib].ds)[iqs/QK8_1] = d; + } +} + +void quantize_row_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, + const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % QK8_1 == 0); + + const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, kx1*channels, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx, kx_padded); + quantize_q8_1<<>>(x, vy, kx0, kx0_padded); + + GGML_UNUSED(type_x); } +void quantize_mmq_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, + const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); + + const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, kx1, channels); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + if (mmq_need_sum(type_x)) { + quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + } else { + quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + } +} diff --git a/ggml-cuda/quantize.cuh b/ggml-cuda/quantize.cuh index b37a4752f2d24..486c9360a46fd 100644 --- a/ggml-cuda/quantize.cuh +++ b/ggml-cuda/quantize.cuh @@ -1,5 +1,20 @@ +#pragma once + #include "common.cuh" +#include "mmq.cuh" + +#include #define CUDA_QUANTIZE_BLOCK_SIZE 256 -void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream); +typedef void (*quantize_cuda_t)( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + +void quantize_row_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + +void quantize_mmq_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index e0c512c0dab0f..128769177f102 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -345,15 +345,12 @@ struct vk_context { }; struct ggml_tensor_extra_gpu { - bool ready; - size_t ctx_idx; vk_buffer_ref buffer_gpu; uint64_t offset; void reset() { - ready = false; ctx_idx = 0; buffer_gpu.reset(); offset = 0; @@ -2949,7 +2946,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); vk_buffer d_X; @@ -2958,12 +2955,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su uint64_t y_buf_offset = 0; if (!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (qx_needs_dequant) { @@ -3114,7 +3111,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; @@ -3122,12 +3119,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context uint64_t y_buf_offset = 0; if(!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if(!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (qx_needs_dequant) { @@ -3246,14 +3243,14 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_Qx = extra_src0->buffer_gpu.lock(); - const uint64_t qx_buf_offset = extra_src0->offset; + const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qx != nullptr); } @@ -3323,14 +3320,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_Qx = extra_src0->buffer_gpu.lock(); - const uint64_t qx_buf_offset = extra_src0->offset; + const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qx != nullptr); } @@ -3459,7 +3456,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; @@ -3467,17 +3464,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * uint64_t y_buf_offset = 0; if (!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (!ids_uma) { d_ids = extra_ids->buffer_gpu.lock(); - ids_buf_offset = extra_ids->offset; + ids_buf_offset = extra_ids->offset + ids->view_offs; GGML_ASSERT(d_ids != nullptr); } if (qx_needs_dequant) { @@ -3636,7 +3633,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; @@ -3644,17 +3641,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint64_t y_buf_offset = 0; if(!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if(!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if(!ids_uma) { d_ids = extra_ids->buffer_gpu.lock(); - ids_buf_offset = extra_ids->offset; + ids_buf_offset = extra_ids->offset + ids->view_offs; GGML_ASSERT(d_ids != nullptr); } if (qx_needs_dequant) { @@ -3769,9 +3766,9 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; const vk_buffer src_buf = extra_src0->buffer_gpu.lock(); - const uint64_t src_offset = extra_src0->offset; + const uint64_t src_offset = extra_src0->offset + src0->view_offs; vk_buffer dst_buf = extra->buffer_gpu.lock(); - const uint64_t dst_offset = extra->offset; + const uint64_t dst_offset = extra->offset + dst->view_offs; std::vector copies; @@ -4062,21 +4059,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c } GGML_ASSERT(d_D != nullptr); - uint64_t d_buf_offset = (extra->offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + uint64_t d_buf_offset = ((extra->offset + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT if(!src0_uma) { d_X = extra_src0->buffer_gpu.lock(); - x_buf_offset = extra_src0->offset; + x_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_X != nullptr); } if (use_src1 && !src1_uma) { d_Y = extra_src1->buffer_gpu.lock(); - y_buf_offset = extra_src1->offset; + y_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Y != nullptr); } if (use_src2 && !src2_uma) { d_Z = extra_src2->buffer_gpu.lock(); - z_buf_offset = extra_src2->offset; + z_buf_offset = extra_src2->offset + src2->view_offs; GGML_ASSERT(d_Z != nullptr); } @@ -4336,7 +4333,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; + const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { (uint32_t)ggml_nelements(src0), @@ -5569,6 +5566,13 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod const ggml_tensor * src2 = node->src[2]; switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: @@ -5590,10 +5594,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_DIAG_MASK_INF: @@ -5601,7 +5601,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ROPE: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - case GGML_OP_NONE: case GGML_OP_ARGSORT: case GGML_OP_SUM_ROWS: break; @@ -5654,12 +5653,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_DUP: ggml_vk_cpy(ctx, ctx->compute_ctx, src0, node); - break; - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NONE: break; case GGML_OP_NORM: ggml_vk_norm(ctx, ctx->compute_ctx, src0, node); @@ -5712,7 +5705,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return; } - extra->ready = true; extra->ctx_idx = ctx->compute_ctx->idx; #ifdef GGML_VULKAN_CHECK_RESULTS @@ -5796,8 +5788,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_ ggml_vk_check_results_0(ctx, params, tensor); #endif - GGML_ASSERT(extra->ready); - vk_context& subctx = ctx->gc.contexts[extra->ctx_idx]; // Only run if ctx hasn't been submitted yet @@ -5822,8 +5812,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_ subctx.out_memcpys.clear(); } - extra->ready = false; - return true; } @@ -5943,7 +5931,9 @@ struct ggml_backend_vk_buffer_context { ~ggml_backend_vk_buffer_context() { ggml_vk_destroy_buffer(dev_buffer); - delete[] temp_tensor_extras; + if (temp_tensor_extras != nullptr) { + delete[] temp_tensor_extras; + } } ggml_tensor_extra_gpu * ggml_vk_alloc_temp_tensor_extra() { @@ -5990,18 +5980,16 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b #endif ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; - ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra(); - if (tensor->view_src != nullptr && tensor->view_src->extra != nullptr) { + if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); - ggml_tensor_extra_gpu * extra_view = (ggml_tensor_extra_gpu *) tensor->view_src->extra; - extra->buffer_gpu = extra_view->buffer_gpu; - extra->offset = extra_view->offset + tensor->view_offs; + GGML_ASSERT(tensor->view_src->extra != nullptr); + tensor->extra = tensor->view_src->extra; } else { + ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra(); extra->buffer_gpu = ctx->dev_buffer; extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; + tensor->extra = extra; } - - tensor->extra = extra; } GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -6014,7 +6002,7 @@ GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t bu vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_write(ctx->ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_write(ctx->ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -6027,7 +6015,7 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_read(ctx->ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_read(ctx->ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { @@ -6038,7 +6026,7 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu vk_buffer src_buf = src_extra->buffer_gpu.lock(); vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); - ggml_vk_buffer_copy(dst_buf, dst_extra->offset, src_buf, src_extra->offset, ggml_nbytes(src)); + ggml_vk_buffer_copy(dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); return true; } @@ -6264,7 +6252,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_write_async(ctx, ctx->transfer_ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_write_async(ctx, ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -6284,7 +6272,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_read_async(ctx, ctx->transfer_ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_read_async(ctx, ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { @@ -6305,7 +6293,7 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c vk_buffer src_buf = src_extra->buffer_gpu.lock(); vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); - ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset, src_buf, src_extra->offset, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); return true; } @@ -6478,11 +6466,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const // return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; // } break; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - - return true; - } break; + return true; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -6725,7 +6709,7 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size); + ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); } std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; @@ -6809,7 +6793,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (ggml_backend_buffer_is_vk(src0->buffer)) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset; + uint64_t offset = extra->offset + src0->view_offs; if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { for (int i3 = 0; i3 < src0->ne[3]; i3++) { for (int i2 = 0; i2 < src0->ne[2]; i2++) { @@ -6851,7 +6835,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (ggml_backend_buffer_is_vk(src1->buffer)) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset; + uint64_t offset = extra->offset + src1->view_offs; if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { for (int i3 = 0; i3 < src1->ne[3]; i3++) { for (int i2 = 0; i2 < src1->ne[2]; i2++) { @@ -6909,7 +6893,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (ggml_backend_buffer_is_vk(src2->buffer)) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset; + uint64_t offset = extra->offset + src2->view_offs; if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { for (int i3 = 0; i3 < src2->ne[3]; i3++) { for (int i2 = 0; i2 < src2->ne[2]; i2++) { @@ -7092,11 +7076,11 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - if (extra->offset + tensor_size >= buffer_gpu->size) { - tensor_size = buffer_gpu->size - (extra->offset); + if (extra->offset + tensor->view_offs + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs); } - ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size); + ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); } float first_error_result = -1.0f; diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b93747aff58b3..ed56abfb3c2ea 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -5,6 +5,7 @@ import shutil import struct import tempfile +from dataclasses import dataclass from enum import Enum, auto from io import BufferedWriter from typing import IO, Any, Sequence, Mapping @@ -30,17 +31,36 @@ logger = logging.getLogger(__name__) +@dataclass +class TensorInfo: + shape: Sequence[int] + dtype: GGMLQuantizationType + nbytes: int + tensor: np.ndarray[Any, Any] | None = None + + +@dataclass +class GGUFValue: + value: Any + type: GGUFValueType + + class WriterState(Enum): + NO_FILE = auto() EMPTY = auto() HEADER = auto() KV_DATA = auto() TI_DATA = auto() + WEIGHTS = auto() class GGUFWriter: - fout: BufferedWriter + fout: BufferedWriter | None + path: os.PathLike[str] | str | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: list[np.ndarray[Any, Any]] + tensors: dict[str, TensorInfo] + kv_data: dict[str, GGUFValue] + state: WriterState _simple_value_packing = { GGUFValueType.UINT8: "B", GGUFValueType.INT8: "b", @@ -56,141 +76,140 @@ class GGUFWriter: } def __init__( - self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True, + self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, ): - self.fout = open(path, "wb") + self.fout = None + self.path = path self.arch = arch self.endianess = endianess - self.offset_tensor = 0 self.data_alignment = GGUF_DEFAULT_ALIGNMENT - self.kv_data = bytearray() - self.kv_data_count = 0 - self.ti_data = bytearray() - self.ti_data_count = 0 - self.ti_names = set() self.use_temp_file = use_temp_file self.temp_file = None - self.tensors = [] + self.tensors = dict() + self.kv_data = dict() logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) - self.state = WriterState.EMPTY + self.state = WriterState.NO_FILE self.add_architecture() - def write_header_to_file(self) -> None: + def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: + if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): + # allow calling this multiple times as long as the path is the same + return + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + + if path is not None: + self.path = path + + if self.path is not None: + if self.fout is not None: + self.fout.close() + self.fout = open(self.path, "wb") + self.state = WriterState.EMPTY + + def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: + self.open_output_file(path) + if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') self._write_packed(" None: if self.state is not WriterState.HEADER: raise ValueError(f'Expected output file to contain the header, got {self.state}') + assert self.fout is not None + + kv_data = bytearray() + + for key, val in self.kv_data.items(): + kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_data += self._pack_val(val.value, val.type, add_vtype=True) - self.fout.write(self.kv_data) + self.fout.write(kv_data) self.flush() self.state = WriterState.KV_DATA def write_ti_data_to_file(self) -> None: if self.state is not WriterState.KV_DATA: raise ValueError(f'Expected output file to contain KV data, got {self.state}') - - self.fout.write(self.ti_data) + assert self.fout is not None + + ti_data = bytearray() + offset_tensor = 0 + + for name, ti in self.tensors.items(): + ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) + n_dims = len(ti.shape) + ti_data += self._pack("I", n_dims) + for i in range(n_dims): + ti_data += self._pack("Q", ti.shape[n_dims - 1 - i]) + ti_data += self._pack("I", ti.dtype) + ti_data += self._pack("Q", offset_tensor) + offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) + + self.fout.write(ti_data) self.flush() self.state = WriterState.TI_DATA - def add_key(self, key: str) -> None: - self.add_val(key, GGUFValueType.STRING, add_vtype=False) + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: + if key in self.kv_data: + raise ValueError(f'Duplicated key name {key!r}') + + self.kv_data[key] = GGUFValue(value=val, type=vtype) def add_uint8(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT8) + self.add_key_value(key,val, GGUFValueType.UINT8) def add_int8(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT8) + self.add_key_value(key, val, GGUFValueType.INT8) def add_uint16(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT16) + self.add_key_value(key, val, GGUFValueType.UINT16) def add_int16(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT16) + self.add_key_value(key, val, GGUFValueType.INT16) def add_uint32(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT32) + self.add_key_value(key, val, GGUFValueType.UINT32) def add_int32(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT32) + self.add_key_value(key, val, GGUFValueType.INT32) def add_float32(self, key: str, val: float) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.FLOAT32) + self.add_key_value(key, val, GGUFValueType.FLOAT32) def add_uint64(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT64) + self.add_key_value(key, val, GGUFValueType.UINT64) def add_int64(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT64) + self.add_key_value(key, val, GGUFValueType.INT64) def add_float64(self, key: str, val: float) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.FLOAT64) + self.add_key_value(key, val, GGUFValueType.FLOAT64) def add_bool(self, key: str, val: bool) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.BOOL) + self.add_key_value(key, val, GGUFValueType.BOOL) def add_string(self, key: str, val: str) -> None: if not val: return - self.add_key(key) - self.add_val(val, GGUFValueType.STRING) + self.add_key_value(key, val, GGUFValueType.STRING) def add_array(self, key: str, val: Sequence[Any]) -> None: if not isinstance(val, Sequence): raise ValueError("Value must be a sequence for array type") - self.add_key(key) - self.add_val(val, GGUFValueType.ARRAY) - - def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None: - if vtype is None: - vtype = GGUFValueType.get_type(val) - - if add_vtype: - self.kv_data += self._pack("I", vtype) - self.kv_data_count += 1 - - pack_fmt = self._simple_value_packing.get(vtype) - if pack_fmt is not None: - self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) - elif vtype == GGUFValueType.STRING: - encoded_val = val.encode("utf-8") if isinstance(val, str) else val - self.kv_data += self._pack("Q", len(encoded_val)) - self.kv_data += encoded_val - elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: - ltype = GGUFValueType.get_type(val[0]) - if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): - raise ValueError("All items in a GGUF array should be of the same type") - self.kv_data += self._pack("I", ltype) - self.kv_data += self._pack("Q", len(val)) - for item in val: - self.add_val(item, add_vtype=False) - else: - raise ValueError("Invalid GGUF metadata value type or value") + self.add_key_value(key, val, GGUFValueType.ARRAY) @staticmethod def ggml_pad(x: int, n: int) -> int: @@ -200,16 +219,12 @@ def add_tensor_info( self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.state is not WriterState.EMPTY: - raise ValueError(f'Expected output file to be empty, got {self.state}') + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') - if name in self.ti_names: - raise ValueError(f'Duplicated tensor name {name}') - self.ti_names.add(name) + if name in self.tensors: + raise ValueError(f'Duplicated tensor name {name!r}') - encoded_name = name.encode("utf-8") - self.ti_data += self._pack("Q", len(encoded_name)) - self.ti_data += encoded_name if raw_dtype is None: if tensor_dtype == np.float16: dtype = GGMLQuantizationType.F16 @@ -231,14 +246,8 @@ def add_tensor_info( dtype = raw_dtype if tensor_dtype == np.uint8: tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) - n_dims = len(tensor_shape) - self.ti_data += self._pack("I", n_dims) - for i in range(n_dims): - self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) - self.ti_data += self._pack("I", dtype) - self.ti_data += self._pack("Q", self.offset_tensor) - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) - self.ti_data_count += 1 + + self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, @@ -252,10 +261,10 @@ def add_tensor( self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) if self.temp_file is None: - self.tensors.append(tensor) + self.tensors[name].tensor = tensor return tensor.tofile(self.temp_file) @@ -267,8 +276,9 @@ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None fp.write(bytes([0] * pad)) def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: - if self.state is not WriterState.TI_DATA: - raise ValueError(f'Expected output file to contain tensor info, got {self.state}') + if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: + raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') + assert self.fout is not None if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) @@ -276,50 +286,51 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: tensor.tofile(self.fout) self.write_padding(self.fout, tensor.nbytes) + self.state = WriterState.WEIGHTS + def write_tensors_to_file(self, *, progress: bool = False) -> None: self.write_ti_data_to_file() + assert self.fout is not None + self.write_padding(self.fout, self.fout.tell()) if self.temp_file is None: - self.tensors.reverse() # to pop from the "beginning" in constant time + bar = None if progress: from tqdm import tqdm - total_bytes = sum(t.nbytes for t in self.tensors) + total_bytes = sum(t.nbytes for t in self.tensors.values()) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - while True: - try: - tensor = self.tensors.pop() - except IndexError: - break - tensor.tofile(self.fout) - bar.update(tensor.nbytes) - self.write_padding(self.fout, tensor.nbytes) - return - while True: - try: - tensor = self.tensors.pop() - except IndexError: - break - tensor.tofile(self.fout) - self.write_padding(self.fout, tensor.nbytes) - return + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in self.tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + ti.tensor.tofile(self.fout) + if bar is not None: + bar.update(ti.nbytes) + self.write_padding(self.fout, ti.nbytes) + ti.tensor = None + else: + self.temp_file.seek(0) - self.temp_file.seek(0) + shutil.copyfileobj(self.temp_file, self.fout) + self.flush() + self.temp_file.close() - shutil.copyfileobj(self.temp_file, self.fout) - self.flush() - self.temp_file.close() + self.state = WriterState.WEIGHTS def flush(self) -> None: + assert self.fout is not None self.fout.flush() def close(self) -> None: - self.fout.close() + if self.fout is not None: + self.fout.close() + self.fout = None def add_architecture(self) -> None: self.add_string(Keys.General.ARCHITECTURE, self.arch) @@ -449,7 +460,7 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) - def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: + def add_rope_scaling_attn_factors(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) def add_rope_scaling_orig_ctx_len(self, value: int) -> None: @@ -571,5 +582,32 @@ def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' return struct.pack(f'{pack_prefix}{fmt}', value) + def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes: + kv_data = bytearray() + + if add_vtype: + kv_data += self._pack("I", vtype) + + pack_fmt = self._simple_value_packing.get(vtype) + if pack_fmt is not None: + kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + elif vtype == GGUFValueType.STRING: + encoded_val = val.encode("utf-8") if isinstance(val, str) else val + kv_data += self._pack("Q", len(encoded_val)) + kv_data += encoded_val + elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: + ltype = GGUFValueType.get_type(val[0]) + if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): + raise ValueError("All items in a GGUF array should be of the same type") + kv_data += self._pack("I", ltype) + kv_data += self._pack("Q", len(val)) + for item in val: + kv_data += self._pack_val(item, ltype, add_vtype=False) + else: + raise ValueError("Invalid GGUF metadata value type or value") + + return kv_data + def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: + assert self.fout is not None self.fout.write(self._pack(fmt, value, skip_pack_prefix)) diff --git a/gguf-py/scripts/gguf-new-metadata.py b/gguf-py/scripts/gguf-new-metadata.py index 21e91180cd340..c4b90d5810a65 100755 --- a/gguf-py/scripts/gguf-new-metadata.py +++ b/gguf-py/scripts/gguf-new-metadata.py @@ -101,8 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Copying {field.name}') if val.value is not None: - writer.add_key(field.name) - writer.add_val(val.value, val.type) + writer.add_key_value(field.name, val.value, val.type) if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: logger.debug('Adding chat template(s)') @@ -111,8 +110,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new for key, val in new_metadata.items(): logger.debug(f'Adding {key}: "{val.value}" {val.description}') - writer.add_key(key) - writer.add_val(val.value, val.type) + writer.add_key_value(key, val.value, val.type) total_bytes = 0