Skip to content

Commit

Permalink
llama : add support for Chameleon (ggerganov#8543)
Browse files Browse the repository at this point in the history
* convert chameleon hf to gguf

* add chameleon tokenizer tests

* fix lint

* implement chameleon graph

* add swin norm param

* return qk norm weights and biases to original format

* implement swin norm

* suppress image token output

* rem tabs

* add comment to conversion

* fix ci

* check for k norm separately

* adapt to new lora implementation

* fix layer input for swin norm

* move swin_norm in gguf writer

* add comment regarding special token regex in chameleon pre-tokenizer

* Update src/llama.cpp

Co-authored-by: compilade <[email protected]>

* fix punctuation regex in chameleon pre-tokenizer (@compilade)

Co-authored-by: compilade <[email protected]>

* fix lint

* trigger ci

---------

Co-authored-by: compilade <[email protected]>
  • Loading branch information
2 people authored and arthw committed Nov 18, 2024
1 parent b697557 commit 84597bf
Show file tree
Hide file tree
Showing 10 changed files with 505 additions and 2 deletions.
44 changes: 44 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085":
# ref: https://huggingface.co/microsoft/phi-2
res = "phi-2"
if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450":
# ref: https://huggingface.co/facebook/chameleon-7b
res = "chameleon"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -4138,6 +4141,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return super().modify_tensors(data_torch, name, bid)


@Model.register("ChameleonForCausalLM")
class ChameleonModel(Model):
model_arch = gguf.MODEL_ARCH.CHAMELEON

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False))

def set_vocab(self):
self._set_vocab_gpt2()

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# ignore image tokenizer for now
# TODO: remove this once image support is implemented for Chameleon
if name.startswith("model.vqmodel"):
return []

n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
hidden_dim = self.hparams.get("hidden_size")

if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
if name.endswith(("q_norm.weight", "q_norm.bias")):
data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim)
if name.endswith(("k_norm.weight", "k_norm.bias")):
data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim)

return [(self.map_tensor_name(name), data_torch)]

# see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203
@staticmethod
def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
head_dim = hidden_dim // n_heads
data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1)
data_torch = data_torch.repeat_interleave(n_heads, 0)
return data_torch


###### CONVERSION LOGIC ######


Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class TOKENIZER_TYPE(IntEnum):
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
{"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
{"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
]


Expand Down
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class LLM:
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm"
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
Expand Down Expand Up @@ -236,6 +237,7 @@ class MODEL_ARCH(IntEnum):
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
CHAMELEON = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -394,6 +396,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -1260,6 +1263,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.CHAMELEON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ def add_expert_shared_count(self, count: int) -> None:
def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

def add_swin_norm(self, value: bool) -> None:
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)

def add_rescale_every_n_layers(self, count: int) -> None:
self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)

Expand Down
4 changes: 2 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
Expand All @@ -389,7 +389,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
};

enum llama_rope_type {
Expand Down
112 changes: 112 additions & 0 deletions models/ggml-vocab-chameleon.gguf.inp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__


__ggml_vocab_test__



__ggml_vocab_test__




__ggml_vocab_test__


__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__

=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__











🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__
46 changes: 46 additions & 0 deletions models/ggml-vocab-chameleon.gguf.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
17245 16604 16403 16604 33583 18355
16421 51153

16604
16650
16650 16604
16581
16582
16582 16582
16582 16582 16582
16581 16582
31596 17394
34926 17394
31596 18671
34926 18671
34926 18671 16384
31596 16395 17394 16384
34926 16395 17394 16384
16811 16704 20410 16483 16631 16397 52854
16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639
29479 23955 17012 20103 25527 27670 17408 19005 21473 24774
54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615
52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392
31596
34926
16650 31596
16650 34926
16696 31596
16696 31596 16582 16696 31596
16604 16391
16582 16604 16412
16390 22623
31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636
16384 16384 16384 16384 16384 16384
16402
16402 16402
16402 16402 16402
16402 16402 16402 16402
16402 16402 16402 16402 16402
16402 16402 16402 16402 16402 16402
16402 16402 16402 16402 16402 16402 16402
16402 16402 16402 16402 16402 16402 16402 16402
16402 16402 16402 16402 16402 16402 16402 16402 16402
16418 19038 16639 16448 24315 33727 16467
18765 17981
16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427
14 changes: 14 additions & 0 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,20 @@ struct llm_tokenizer_bpe {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_CHAMELEON:
// Note: in theory, the special token (sentinel and image token) regex_exprs below
// are unnecessary, as they are split in `tokenizer_st_partition` anyway.
// However, since the upstream pre-tokenizer uses them, they are also
// included here (see https://huggingface.co/facebook/chameleon-7b).
regex_exprs = {
"<sentinel:[0-9]+>", // Sentinel tokens
"(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens
"([\\t\\n]| | )", // directly from tokenizer.json
"\\p{N}", // Individual digits
"[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
Expand Down
Loading

0 comments on commit 84597bf

Please sign in to comment.