Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support glm3 and glm4. #8031

Merged
merged 39 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6630a2d
add chatglm3-6b model support huggingface model:
xingxingqiao May 29, 2024
5a914ff
remove .rotary_pos_emb.inv_freq and unuse code for chatglm3 model
xingxingqiao May 15, 2024
f626b71
fix lint error
xingxingqiao May 24, 2024
f3bc337
optimize convert-hf-to-gguf.py for chatglm model
xingxingqiao May 16, 2024
1fc5bf5
support glm-4-9b-chat
xingxingqiao Jun 17, 2024
8c5f1b2
fix eos tokens to glm4
youth123 Jun 20, 2024
95fd910
remove unused log
youth123 Jun 20, 2024
e773174
Fix eos tokens to glm4 and adapts to glm3
youth123 Jun 20, 2024
4b65b64
add preprocess to chatglm3 and chatglm4
youth123 Jun 21, 2024
3a4d579
add eos_id_list to llama.cpp
youth123 Jun 24, 2024
9570806
fix conflicts
youth123 Jun 25, 2024
3b67ff8
fix code style
youth123 Jun 25, 2024
5f8f465
fix code style
youth123 Jun 25, 2024
f8d4fc9
fix conflicts
youth123 Jun 25, 2024
a67bc8f
fix conflicts
youth123 Jun 25, 2024
3557944
Merge branch 'glm_support'
youth123 Jun 25, 2024
89e8aaf
Revert "add eos_id_list to llama.cpp"
youth123 Jun 25, 2024
9396c7b
set <|endoftext|> as eos and <|user|> as eot
youth123 Jun 26, 2024
e18a536
Merge remote-tracking branch 'offical/master'
youth123 Jun 26, 2024
0595f03
fix chat template bug
youth123 Jun 26, 2024
7357273
add comment to glm prefix and suffix
youth123 Jun 27, 2024
1dc8e91
Merge remote-tracking branch 'offical/master'
youth123 Jun 27, 2024
e9e47eb
fix conflicts and add rope_ratio & ChatGLMForConditionalGeneration
youth123 Jun 27, 2024
482bdea
merge master
youth123 Jun 28, 2024
bbe1926
fix chat template bug
youth123 Jun 28, 2024
d07f0a9
fix codestyle
youth123 Jul 1, 2024
0d3a94a
merge master
youth123 Jul 1, 2024
5e9dba6
fix conflicts
youth123 Jul 1, 2024
865dd03
modified the general name of glm model
youth123 Jul 1, 2024
71c8e02
Merge remote-tracking branch 'offical/master'
youth123 Jul 2, 2024
ec89d06
merge master
youth123 Jul 3, 2024
80b381b
fix conflicts
youth123 Jul 3, 2024
bf54db2
remove prefix and suffix
youth123 Jul 3, 2024
bce74d8
use normal glm4 chattempalte & use LLM_FFN_SWIGLU in phi3
youth123 Jul 3, 2024
3be4270
fix: resolve Flake8 errors in `convert-hf-to-gguf.py`
Umpire2018 Jul 5, 2024
ed54a65
Merge pull request #2 from Umpire2018/fix/flake8-error
youth123 Jul 7, 2024
5b760f2
fix rope ratio to solve incorrect answers
youth123 Jul 7, 2024
223eb18
merge master
youth123 Jul 7, 2024
4e85b06
fix by comments
youth123 Jul 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
res = "jina-v2-code"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
# ref: https://huggingface.co/LumiOpen/Viking-7B
res = "viking"
Expand Down Expand Up @@ -2942,6 +2945,192 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM

def set_vocab_chatglm3(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []
scores: list[float] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
assert max(tokenizer.get_vocab().values()) < vocab_size
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
print(vocab_size)
print(max(tokenizer.get_vocab().values()))
for token_id in range(vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
if token_id == 0:
piece = "<unk>"
elif token_id == 1:
piece = "<bos>"
elif token_id == 2:
piece = "<eos>"

text = piece.encode("utf-8")
score = 0.0
# Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
# it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
score = tokenizer.tokenizer.sp_model.get_score(token_id)

if len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")

if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
if piece in special_tokens:
# show special tokens in prompt
toktype = SentencePieceTokenTypes.USER_DEFINED
else:
toktype = SentencePieceTokenTypes.UNKNOWN
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
continue

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.tokenizer.sp_model.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens.append(text)
scores.append(score)
toktypes.append(toktype)

self.gguf_writer.add_tokenizer_model("llama")
# glm3 needs prefix and suffix formatted as:
# prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"
self.gguf_writer.add_tokenizer_pre("chatglm-spm")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

@staticmethod
def token_bytes_to_string(b):
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
byte_encoder = bytes_to_unicode()
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])

@staticmethod
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
parts = [bytes([b]) for b in token]
while True:
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
return parts

def set_vocab(self):
if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
self.set_vocab_chatglm3()
return

dir_model = self.dir_model
hparams = self.hparams
tokens: list[str] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams["padded_vocab_size"]
assert max(tokenizer.get_vocab().values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)

merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
assert len(merged) >= 2 and len(merged) <= 7
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))

# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
added_vocab = tokenizer.get_added_vocab()
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.chat_template = "chatglm4"
special_vocab.merges = merges
# only add special tokens when they were not already loaded from config.json
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
# this one is usually not in config.json anyway
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_dimension_count(64)
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_rope_freq_base(self.hparams.get("rope_ratio", 10000))


def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.endswith(".rotary_pos_emb.inv_freq"):
return []

name = name.removeprefix("transformer.")
return [(self.map_tensor_name(name), data_torch)]


###### CONVERSION LOGIC ######


Expand Down
18 changes: 17 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ class Tokenizer:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"


#
# recommended mapping of model tensor names for storage in gguf
#
Expand Down Expand Up @@ -161,6 +160,7 @@ class MODEL_ARCH(IntEnum):
OLMO = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
BITNET = auto()
T5 = auto()

Expand Down Expand Up @@ -285,6 +285,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
}
Expand Down Expand Up @@ -906,6 +907,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.CHATGLM : [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
Expand Down Expand Up @@ -990,6 +1003,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
}

#
Expand Down
16 changes: 13 additions & 3 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class TensorNameMap:
"backbone.embedding", # mamba
"backbone.embeddings", # mamba-hf
"transformer.in_out_embed", # Grok
"embedding.word_embeddings", # chatglm
"shared", # t5
),

Expand Down Expand Up @@ -53,6 +54,7 @@ class TensorNameMap:
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
"output_layer", # chatglm
),

# Output norm
Expand All @@ -69,11 +71,13 @@ class TensorNameMap:
"model.norm_f", # mamba-qbert
"backbone.norm_f", # mamba
"transformer.rms_norm", # Grok
"encoder.final_layernorm", # chatglm
),

# Rope frequencies
MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs", # llama-pth
"rotary_pos_emb.inv_freq", # chatglm
),
}

Expand All @@ -98,6 +102,7 @@ class TensorNameMap:
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
),

# Attention norm 2
Expand All @@ -119,7 +124,8 @@ class TensorNameMap:
"h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
"model.layers.{bid}.self_attn.qkv_proj" # phi3
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
),

# Attention query
Expand All @@ -130,7 +136,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
),

# Attention key
Expand All @@ -142,7 +148,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.k", # refact
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
),

# Attention value
Expand Down Expand Up @@ -177,6 +183,7 @@ class TensorNameMap:
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
),

# Attention output norm
Expand Down Expand Up @@ -212,6 +219,7 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
),

# Post feed-forward norm
Expand Down Expand Up @@ -261,6 +269,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
),

MODEL_TENSOR.FFN_UP_EXP: (
Expand Down Expand Up @@ -328,6 +337,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
),

MODEL_TENSOR.FFN_DOWN_EXP: (
Expand Down
4 changes: 3 additions & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
};

// note: these values should be synchronized with ggml_rope
Expand Down
Loading