From 385c1a8cd48ae87067413f99de6cc2500fbad835 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:03:07 +0200 Subject: [PATCH 01/20] convert chameleon hf to gguf --- convert_hf_to_gguf.py | 25 +++++++++++++++++++++++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 18 ++++++++++++++++++ gguf-py/gguf/tensor_mapping.py | 4 ++-- 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 42dace219f20f..bcf392595e277 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -521,6 +521,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901": # ref: https://huggingface.co/core42/jais-13b res = "jais" + if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": + # ref: https://huggingface.co/facebook/chameleon-7b + res = "chameleon" if res is None: logger.warning("\n") @@ -3419,6 +3422,28 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] +@Model.register("ChameleonForCausalLM") +class ChameleonModel(Model): + model_arch = gguf.MODEL_ARCH.CHAMELEON + + def set_vocab(self): + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # ignore image tokenizer for now + if name.startswith("model.vqmodel"): + return [] + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + return [(self.map_tensor_name(name), data_torch)] + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index e4165ae2d977c..5e1d060919fdb 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -91,6 +91,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a95a44237e348..bb0d9beb83733 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -166,6 +166,7 @@ class MODEL_ARCH(IntEnum): BITNET = auto() T5 = auto() JAIS = auto() + CHAMELEON = auto() class MODEL_TENSOR(IntEnum): @@ -293,6 +294,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.CHAMELEON: "chameleon", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -996,6 +998,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.CHAMELEON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + 
MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7264240f5e17a..6b03194629b05 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -364,7 +364,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon - "model.layers.{bid}.self_attn.q_norm", # cohere + "model.layers.{bid}.self_attn.q_norm", # cohere chameleon "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm @@ -373,7 +373,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon - "model.layers.{bid}.self_attn.k_norm", # cohere + "model.layers.{bid}.self_attn.k_norm", # cohere chameleon "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm From 568110aab50fc96e5082adc8bd40dd1ca3aa20bf Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:17:04 +0200 Subject: [PATCH 02/20] add chameleon tokenizer tests --- models/ggml-vocab-chameleon.gguf.inp | 112 +++++++++++++++++++++++++++ models/ggml-vocab-chameleon.gguf.out | 46 +++++++++++ 2 files changed, 158 insertions(+) create mode 100644 models/ggml-vocab-chameleon.gguf.inp create mode 100644 models/ggml-vocab-chameleon.gguf.out diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp new file mode 100644 index 0000000000000..9baf7d77ae6b5 --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! 
+__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out new file mode 100644 index 0000000000000..7c5413fee0adf --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.out @@ -0,0 +1,46 @@ + 17245 16604 16403 16604 33583 18355 + 16421 51153 + + 16604 + 16650 + 16650 16604 + 16581 + 16582 + 16582 16582 + 16582 16582 16582 + 16581 16582 + 31596 17394 + 34926 17394 + 31596 18671 + 34926 18671 + 34926 18671 16384 + 31596 16395 17394 16384 + 34926 16395 17394 16384 + 16811 16704 20410 16483 16631 16397 52854 + 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 + 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 + 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 + 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 + 31596 + 34926 + 16650 31596 + 16650 34926 + 16696 31596 + 16696 31596 16582 16696 31596 + 16604 16391 + 16582 16604 16412 + 16390 22623 + 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 + 16384 16384 16384 16384 16384 16384 + 16402 + 16402 16402 + 16402 16402 16402 + 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 16402 + 16418 19038 16639 16448 24315 33727 16467 + 18765 17981 + 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 
17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 From fc09437496dd5b62bab1237ea1699db27a7941a8 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:26:20 +0200 Subject: [PATCH 03/20] fix lint --- convert_hf_to_gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bcf392595e277..830917c964efd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3422,6 +3422,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] + @Model.register("ChameleonForCausalLM") class ChameleonModel(Model): model_arch = gguf.MODEL_ARCH.CHAMELEON From 0453f7d114ae7fd8a8c72b86b3944d6372eb20ba Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 15 Jul 2024 17:51:16 +0200 Subject: [PATCH 04/20] implement chameleon graph --- include/llama.h | 1 + src/llama.cpp | 234 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+) diff --git a/include/llama.h b/include/llama.h index 3970c3aebcd62..b788ed6a31276 100644 --- a/include/llama.h +++ b/include/llama.h @@ -92,6 +92,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, LLAMA_VOCAB_PRE_TYPE_VIKING = 18, LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 20, }; // note: these values should be synchronized with ggml_rope diff --git a/src/llama.cpp b/src/llama.cpp index 400a4232beeb0..40bb3abf502aa 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -239,6 +239,7 @@ enum llm_arch { LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_JAIS, + LLM_ARCH_CHAMELEON, LLM_ARCH_UNKNOWN, }; @@ -283,6 +284,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1296,6 +1298,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -5217,6 +5238,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + + switch (hparams.n_layer) { + case 32: model.type = 
e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_34B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5451,6 +5483,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "jais") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; + } else if ( + tokenizer_pre == "chameleon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -7498,6 +7535,45 @@ static bool llm_load_tensors( layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); } } break; + case LLM_ARCH_CHAMELEON: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -13606,6 +13682,149 @@ struct llm_build_context { return gf; } + + // ref: https://github.com/facebookresearch/chameleon + // based on the original build_llama() function, changes: + // * qk-norm + // * swin-norm (TODO) + // * removed bias + // * removed MoE + struct ggml_cgraph * build_chameleon() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens 
+ int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + 
model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -13853,6 +14072,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_jais(); } break; + case LLM_ARCH_CHAMELEON: + { + result = llm.build_chameleon(); + } break; default: GGML_ASSERT(false); } @@ -15457,6 +15680,16 @@ struct llm_tokenizer_bpe { "\\p{N}", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\t\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}\\$\\+<=>\\^~\\|`]+", // Punctuation + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -19367,6 +19600,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_CHATGLM: + case LLM_ARCH_CHAMELEON: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 654b1b35b3b5ad3a9a6c6555de71745c2703ec6b Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 15 Jul 2024 21:19:27 +0200 Subject: [PATCH 05/20] add swin norm param --- convert_hf_to_gguf.py | 4 ++++ gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ 3 files changed, 8 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 830917c964efd..457cfbfc4489b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3427,6 +3427,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class ChameleonModel(Model): model_arch = gguf.MODEL_ARCH.CHAMELEON + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False)) + def set_vocab(self): self._set_vocab_gpt2() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index bb0d9beb83733..0e43994fcde4a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -52,6 +52,7 @@ class LLM: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index cf95541629032..0975499b9a83f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -701,6 +701,9 @@ def add_middle_token_id(self, id: int) -> None: def add_eot_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.EOT_ID, id) + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: From c460d5c3bb36a8d339689798debc0c68c3dff98f Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 17 Jul 2024 12:53:56 +0200 Subject: [PATCH 06/20] return qk norm weights and biases to original format --- convert_hf_to_gguf.py | 14 ++++++++++++++ 1 file changed, 14 
insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 457cfbfc4489b..5ab8037ef4572 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3441,14 +3441,28 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + hidden_dim = self.hparams.get("hidden_size") if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if name.endswith(("q_norm.weight", "q_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim) + if name.endswith(("k_norm.weight", "k_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) return [(self.map_tensor_name(name), data_torch)] + # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 + @staticmethod + def _reverse_hf_permute(data_torch, n_heads, hidden_dim): + head_dim = hidden_dim // n_heads + data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1) + data_torch = data_torch.repeat_interleave(n_heads, 0) + return data_torch + + ###### CONVERSION LOGIC ###### From 3d3523e4322a871a88244206f250c28637c3b3cd Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 17 Jul 2024 12:55:47 +0200 Subject: [PATCH 07/20] implement swin norm --- src/llama.cpp | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 40bb3abf502aa..083e8a86ce846 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -320,6 +320,7 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -413,6 +414,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -2172,6 +2174,7 @@ struct llama_hparams { bool vocab_only; bool rope_finetuned; bool use_par_res; + bool swin_norm; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -5242,6 +5245,7 @@ static void llm_load_hparams( { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); switch (hparams.n_layer) { case 32: model.type = e_model::MODEL_7B; break; @@ -7560,7 +7564,7 @@ static bool llm_load_tensors( layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, 
"bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); @@ -13686,7 +13690,7 @@ struct llm_build_context { // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm - // * swin-norm (TODO) + // * swin-norm // * removed bias // * removed MoE struct ggml_cgraph * build_chameleon() { @@ -13714,9 +13718,11 @@ struct llm_build_context { struct ggml_tensor * inpSA = inpL; // norm - cur = llm_build_norm(ctx0, inpL, hparams, + if (!hparams.swin_norm) { + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + } cb(cur, "attn_norm", il); // self-attention @@ -13773,6 +13779,12 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + } } if (il == n_layer - 1) { @@ -13787,10 +13799,12 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); + if (!hparams.swin_norm) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, NULL, @@ -13800,6 +13814,13 @@ struct llm_build_context { LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -14818,6 +14839,7 @@ static int llama_decode_internal( GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size); ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float)); } + } // extract embeddings From 758612a984ac258a002b2061ef04c4c409a6bb22 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 17 Jul 2024 16:19:13 +0200 Subject: [PATCH 08/20] suppress image token output --- src/llama.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 083e8a86ce846..c409b162e4e7e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13840,6 +13840,15 @@ struct llm_build_context { // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output_with_img_logits", -1); + + int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_add1(ctx0, img_logits, ggml_new_f32(ctx0, -FLT_MAX)); + cb(img_logits, "img_logits", -1); + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); cb(cur, "result_output", -1); 
ggml_build_forward_expand(gf, cur); From 90766e15e283c6cefbf605ba5285cb158c0e050e Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:18:09 +0200 Subject: [PATCH 09/20] rem tabs --- src/llama.cpp | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c409b162e4e7e..312e6dafb2f9b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5244,7 +5244,7 @@ static void llm_load_hparams( case LLM_ARCH_CHAMELEON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); switch (hparams.n_layer) { @@ -13718,11 +13718,11 @@ struct llm_build_context { struct ggml_tensor * inpSA = inpL; // norm - if (!hparams.swin_norm) { + if (!hparams.swin_norm) { cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - } + } cb(cur, "attn_norm", il); // self-attention @@ -13780,11 +13780,11 @@ struct llm_build_context { model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - if (hparams.swin_norm) { + if (hparams.swin_norm) { cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - } + } } if (il == n_layer - 1) { @@ -13799,12 +13799,12 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - if (!hparams.swin_norm) { + if (!hparams.swin_norm) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - } + } cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, NULL, @@ -13814,12 +13814,12 @@ struct llm_build_context { LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); - if (hparams.swin_norm) { + if (hparams.swin_norm) { cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - } + } cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -13842,13 +13842,15 @@ struct llm_build_context { cur = ggml_mul_mat(ctx0, model.output, cur); cb(cur, "result_output_with_img_logits", -1); - int img_token_end_idx = 8196; - int img_token_start_idx = 4; - int num_img_tokens = img_token_end_idx - img_token_start_idx; - struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); - img_logits = ggml_add1(ctx0, img_logits, ggml_new_f32(ctx0, -FLT_MAX)); - cb(img_logits, "img_logits", -1); - cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); + // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. + // Needs to be removed once image outputs are supported. 
+ int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_add1(ctx0, img_logits, ggml_new_f32(ctx0, -FLT_MAX)); + cb(img_logits, "img_logits", -1); + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -15713,8 +15715,8 @@ struct llm_tokenizer_bpe { break; case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: regex_exprs = { - "", // Sentinel tokens - "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens "([\t\n]| | )", // directly from tokenizer.json "\\p{N}", // Individual digits "[\\p{P}\\$\\+<=>\\^~\\|`]+", // Punctuation From 126201d1a24d8e864f6efc3675c2bf7ab01251e0 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Thu, 18 Jul 2024 08:28:27 +0200 Subject: [PATCH 10/20] add comment to conversion --- convert_hf_to_gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c66a2b4e6bb0d..cdab8b58e77a1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3432,6 +3432,7 @@ def set_vocab(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # ignore image tokenizer for now + # TODO: remove this once image support is implemented for Chameleon if name.startswith("model.vqmodel"): return [] From da5e356dfb4bcc275288ccad4faf21cc23bcb7fd Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:28:26 +0200 Subject: [PATCH 11/20] fix ci --- src/llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 725d24026125e..138294600d26c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13896,7 +13896,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); @@ -13926,7 +13926,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); } - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, From fa568f6a8293327f448a0aa065946b26130ea825 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:30:40 +0200 Subject: [PATCH 12/20] check for k norm separately --- src/llama.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 138294600d26c..2a1556ea88145 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13863,17 +13863,20 @@ struct llm_build_context { ggml_element_size(Qcur) * n_embd_head * n_head, 0); cb(Qcur, "Qcur", il); - Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, - ggml_element_size(Kcur) * n_embd_head, - ggml_element_size(Kcur) * n_embd_head * n_head_kv, - 0); - cb(Kcur, "Kcur", il); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, cb, il); cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = 
ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, From f40cd2073a27390f09127b63a81abf5bcf433ff8 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:59:30 +0200 Subject: [PATCH 13/20] adapt to new lora implementation --- src/llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2a1556ea88145..b65d589d14d7d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13848,13 +13848,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].attn_q_norm) { @@ -13962,7 +13962,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. From 15260c5ba855aeadd4d07331a1682f24f19562c3 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:26:34 +0200 Subject: [PATCH 14/20] fix layer input for swin norm --- src/llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index b65d589d14d7d..6d33b1edddcf6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13838,12 +13838,14 @@ struct llm_build_context { struct ggml_tensor * inpSA = inpL; // norm - if (!hparams.swin_norm) { + if (hparams.swin_norm) { + cur = inpL; + } else { cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); } - cb(cur, "attn_norm", il); // self-attention { From 6e0ded3637eba2ba1c783addfb2b96e58a0db931 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:22:21 +0200 Subject: [PATCH 15/20] move swin_norm in gguf writer --- gguf-py/gguf/gguf_writer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f3261ddac9fc3..af9c063fb2f09 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -546,6 +546,9 @@ def add_expert_shared_count(self, count: int) -> None: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) @@ -704,9 +707,6 @@ def add_middle_token_id(self, id: int) -> None: def add_eot_token_id(self, id: int) -> 
None: self.add_uint32(Keys.Tokenizer.EOT_ID, id) - def add_swin_norm(self, value: bool) -> None: - self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) - def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: From 05f138551fd2d8e5223f0d934dda110bbd373ff1 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:44:24 +0200 Subject: [PATCH 16/20] add comment regarding special token regex in chameleon pre-tokenizer --- src/llama.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 6d33b1edddcf6..d19c5cf8f4174 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15843,6 +15843,10 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). regex_exprs = { "", // Sentinel tokens "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens From 1e1e78a3246cdd23513f0b56795b6db9324bea18 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:46:18 +0000 Subject: [PATCH 17/20] Update src/llama.cpp Co-authored-by: compilade --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index d19c5cf8f4174..db41ad6291f66 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15850,7 +15850,7 @@ struct llm_tokenizer_bpe { regex_exprs = { "", // Sentinel tokens "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens - "([\t\n]| | )", // directly from tokenizer.json + "([\\t\\n]| | )", // directly from tokenizer.json "\\p{N}", // Individual digits "[\\p{P}\\$\\+<=>\\^~\\|`]+", // Punctuation "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", From 0ee896d149ec02ffe472a467a33709c2acce9f91 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:47:35 +0000 Subject: [PATCH 18/20] fix punctuation regex in chameleon pre-tokenizer (@compilade) Co-authored-by: compilade --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index db41ad6291f66..18aea5ad2cd2d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15852,7 +15852,7 @@ struct llm_tokenizer_bpe { "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens "([\\t\\n]| | )", // directly from tokenizer.json "\\p{N}", // Individual digits - "[\\p{P}\\$\\+<=>\\^~\\|`]+", // Punctuation + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; break; From 98ea5e704ca7912662f4d5d8e92ba2125022baff Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Mon, 22 Jul 2024 15:17:40 +0200 Subject: [PATCH 19/20] fix lint --- src/llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 7c1448397ed23..9d887fefbada3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15849,10 +15849,10 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: - // Note: in theory, the special token (sentinel and image token) regex_exprs below - // are unnecessary, as they are split in 
`tokenizer_st_partition` anyway. - // However, since the upstream pre-tokenizer uses them, they are also - // included here (see https://huggingface.co/facebook/chameleon-7b). + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). regex_exprs = { "", // Sentinel tokens "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens From d38a92875192dc2bb10b9dba5f4f50ee4ca64b8f Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 14 Aug 2024 10:24:04 +0200 Subject: [PATCH 20/20] trigger ci
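
A note on the pre-tokenizer rules touched by PATCH 17-19/20 above: the "\p{N}" rule splits every digit into its own chunk, which is why the ggml-vocab-chameleon.gguf.out fixture maps "3", "33", "333", ... to repeated copies of token id 16402, and the isolated-punctuation rule matches each punctuation character on its own. The following minimal Python sketch illustrates those two patterns as copied from the diff; it is not part of the patch series, and it assumes the third-party `regex` package only because, unlike the standard `re` module, it understands \p{...} Unicode classes.

    import regex  # pip install regex; supports \p{...} Unicode classes

    # Two of the Chameleon pre-tokenizer patterns from src/llama.cpp above
    digit_rule = r"\p{N}"                 # individual digits
    punct_rule = r"[\p{P}!-/:-@\[-`{-~]"  # punctuation, isolated

    # Every digit becomes its own chunk, matching the vocab test output
    # where "333" tokenizes to three identical ids (16402 16402 16402).
    print(regex.findall(digit_rule, "3333"))  # ['3', '3', '3', '3']

    # Punctuation is likewise split character by character.
    print(regex.findall(punct_rule, "!!??"))  # ['!', '!', '?', '?']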