From f957027b9aa354b28e9d58c65c1beb5bfa98595d Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Sat, 28 Sep 2024 12:08:43 +0000 Subject: [PATCH] llama : add support for Chameleon (#8543) * convert chameleon hf to gguf * add chameleon tokenizer tests * fix lint * implement chameleon graph * add swin norm param * return qk norm weights and biases to original format * implement swin norm * suppress image token output * rem tabs * add comment to conversion * fix ci * check for k norm separately * adapt to new lora implementation * fix layer input for swin norm * move swin_norm in gguf writer * add comment regarding special token regex in chameleon pre-tokenizer * Update src/llama.cpp Co-authored-by: compilade * fix punctuation regex in chameleon pre-tokenizer (@compilade) Co-authored-by: compilade * fix lint * trigger ci --------- Co-authored-by: compilade --- convert_hf_to_gguf.py | 44 +++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 19 ++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 4 +- include/llama.h | 1 + models/ggml-vocab-chameleon.gguf.inp | 112 ++++++++++++ models/ggml-vocab-chameleon.gguf.out | 46 +++++ src/llama-vocab.cpp | 14 ++ src/llama.cpp | 263 +++++++++++++++++++++++++++ 10 files changed, 505 insertions(+), 2 deletions(-) create mode 100644 models/ggml-vocab-chameleon.gguf.inp create mode 100644 models/ggml-vocab-chameleon.gguf.out diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7be609054d6b8..2cd5a8c11bc18 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -640,6 +640,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085": # ref: https://huggingface.co/microsoft/phi-2 res = "phi-2" + if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": + # ref: https://huggingface.co/facebook/chameleon-7b + res = "chameleon" if res is None: logger.warning("\n") @@ -4138,6 +4141,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@Model.register("ChameleonForCausalLM") +class ChameleonModel(Model): + model_arch = gguf.MODEL_ARCH.CHAMELEON + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False)) + + def set_vocab(self): + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # ignore image tokenizer for now + # TODO: remove this once image support is implemented for Chameleon + if name.startswith("model.vqmodel"): + return [] + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + hidden_dim = self.hparams.get("hidden_size") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if name.endswith(("q_norm.weight", "q_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim) + if name.endswith(("k_norm.weight", "k_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) + + return [(self.map_tensor_name(name), data_torch)] + + # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 + @staticmethod + def _reverse_hf_permute(data_torch, n_heads, hidden_dim): + head_dim = hidden_dim // n_heads + data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1) + data_torch = data_torch.repeat_interleave(n_heads, 0) + return data_torch + + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 021f65abdc45d..4d11059f374d2 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -99,6 +99,7 @@ class TOKENIZER_TYPE(IntEnum): {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 560eee916f27e..2fd2e9d2be828 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -94,6 +94,7 @@ class LLM: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" @@ -236,6 +237,7 @@ class MODEL_ARCH(IntEnum): EXAONE = auto() GRANITE = auto() GRANITE_MOE = auto() + CHAMELEON = auto() class MODEL_TENSOR(IntEnum): @@ -394,6 +396,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1260,6 +1263,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.CHAMELEON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index bd059b45c64d0..5c460ef1bc260 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -670,6 +670,9 @@ def add_expert_shared_count(self, count: int) -> None: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + def add_rescale_every_n_layers(self, count: int) -> None: self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4e850726e9ba4..5ef91f11d312f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -380,7 +380,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon - "model.layers.{bid}.self_attn.q_norm", # cohere olmoe + "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm @@ -389,7 +389,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon - "model.layers.{bid}.self_attn.k_norm", # cohere olmoe + "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm diff --git a/include/llama.h b/include/llama.h index 132937a0700e7..caef0bfff0b7d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -102,6 +102,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, }; enum llama_rope_type { diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp new file mode 100644 index 0000000000000..9baf7d77ae6b5 --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out new file mode 100644 index 0000000000000..7c5413fee0adf --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.out @@ -0,0 +1,46 @@ + 17245 16604 16403 16604 33583 18355 + 16421 51153 + + 16604 + 16650 + 16650 16604 + 16581 + 16582 + 16582 16582 + 16582 16582 16582 + 16581 16582 + 31596 17394 + 34926 17394 + 31596 18671 + 34926 18671 + 34926 18671 16384 + 31596 16395 17394 16384 + 34926 16395 17394 16384 + 16811 16704 20410 16483 16631 16397 52854 + 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 + 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 + 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 + 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 + 31596 + 34926 + 16650 31596 + 16650 34926 + 16696 31596 + 16696 31596 16582 16696 31596 + 16604 16391 + 16582 16604 16412 + 16390 22623 + 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 + 16384 16384 16384 16384 16384 16384 + 16402 + 16402 16402 + 16402 16402 16402 + 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 16402 + 16418 19038 16639 16448 24315 33727 16467 + 18765 17981 + 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a771eccda3017..146d416f770f2 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -450,6 +450,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { diff --git a/src/llama.cpp b/src/llama.cpp index 8853fd513e892..42d234c15f230 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -216,6 +216,7 @@ enum llm_arch { LLM_ARCH_RWKV6, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, LLM_ARCH_UNKNOWN, }; @@ -268,6 +269,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -304,6 +306,7 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, LLM_KV_RESCALE_EVERY_N_LAYERS, LLM_KV_TIME_MIX_EXTRA_DIM, LLM_KV_TIME_DECAY_EXTRA_DIM, @@ -411,6 +414,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, @@ -1499,6 +1503,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2362,6 +2385,7 @@ struct llama_hparams { bool vocab_only; bool rope_finetuned; bool use_par_res; + bool swin_norm; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -6084,6 +6108,18 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_34B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6341,6 +6377,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "exaone") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -8732,6 +8773,45 @@ static bool llm_load_tensors( } } break; + case LLM_ARCH_CHAMELEON: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -15876,6 +15956,184 @@ struct llm_build_context { return gf; } + + // ref: https://github.com/facebookresearch/chameleon + // based on the original build_llama() function, changes: + // * qk-norm + // * swin-norm + // * removed bias + // * removed MoE + struct ggml_cgraph * build_chameleon() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + if (hparams.swin_norm) { + cur = inpL; + } else { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + } + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (!hparams.swin_norm) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output_with_img_logits", -1); + + // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. + // Needs to be removed once image outputs are supported. + int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + // creates 1d tensor of size num_img_tokens and values -FLT_MAX, + // which ensures that text token values are always at least larger than image token values + struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX); + cb(img_logits, "img_logits", -1); + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -16136,6 +16394,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_rwkv6(); } break; + case LLM_ARCH_CHAMELEON: + { + result = llm.build_chameleon(); + } break; default: GGML_ABORT("fatal error"); } @@ -19262,6 +19524,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2