From 6630a2da4863d244725bccb45ae90984beb86172 Mon Sep 17 00:00:00 2001 From: XingXing Qiao Date: Wed, 29 May 2024 13:30:07 +0800 Subject: [PATCH 01/28] add chatglm3-6b model support huggingface model: https://hf-mirror.com/THUDM/chatglm3-6b Signed-off-by: XingXing Qiao --- convert-hf-to-gguf.py | 163 ++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 17 +++ gguf-py/gguf/tensor_mapping.py | 16 ++- gguf-py/pyproject.toml | 2 +- llama.cpp | 200 +++++++++++++++++++++++++++++++++ tests/test-chat-template.cpp | 6 +- 6 files changed, 398 insertions(+), 6 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 98b50d15017d0..ad4bf2274581b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -79,7 +79,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, if not self.is_safetensors: self.part_names = Model.get_model_part_names(self.dir_model, ".bin") self.hparams = Model.load_hparams(self.dir_model) - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_names = None if self.ftype == gguf.LlamaFileType.GUESSED: @@ -2710,6 +2710,167 @@ def write_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("ChatGLMModel") +class ChatGLMModel(Model): + model_arch = gguf.MODEL_ARCH.CHATGLM + + def set_vocab(self): + dir_model = self.dir_model + hparams = self.hparams + tokens: list[bytearray] = [] + toktypes: list[int] = [] + scores: list[float] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) + assert max(tokenizer.get_vocab().values()) < vocab_size + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()} + + for token_id in range(vocab_size): + piece = tokenizer._convert_id_to_token(token_id) + if token_id == 0: + piece = "" + elif token_id == 1: + piece = "" + elif token_id == 2: + piece = "" + + text = piece.encode("utf-8") + score = 0.0 + if len(piece) != 0 and token_id < 64789: + score = tokenizer.tokenizer.sp_model.get_score(token_id) + + if len(piece) == 0: + text = f"[PAD{token_id}]".encode("utf-8") + + if token_id >= 64789: + toktype = SentencePieceTokenTypes.UNKNOWN + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + continue + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.tokenizer.sp_model.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.tokenizer.sp_model.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.tokenizer.sp_model.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.tokenizer.sp_model.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + self.gguf_writer.add_name("ChatGLM-6b-chat") + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_head_kv = self.hparams.get("multi_query_group_num", n_head) + self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) + self.gguf_writer.add_embedding_length(n_embed) + self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed)) + self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_dimension_count(64) + self.gguf_writer.add_add_bos_token(False) + + def write_tensors(self): + block_count = self.hparams["num_layers"] + tensors = dict(self.get_tensors()) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + has_lm_head = True + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + + for name, data_torch in tensors.items(): + if name.endswith(".rotary_pos_emb.inv_freq"): + continue + + if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys(): + has_lm_head = False + + name = re.sub(r'transformer\.', '', name) + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name): + # Map bloom-style qkv_linear to gpt-style qkv_linear + # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa + # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa + qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed)) + data = np.concatenate( + ( + qkv_weights[:, 0, :, :].reshape((-1, n_embed)), + qkv_weights[:, 1, :, :].reshape((-1, n_embed)), + qkv_weights[:, 2, :, :].reshape((-1, n_embed)), + ), + axis=0, + ) + print("re-format attention.linear_qkv.weight") + elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name): + qkv_bias = data.reshape((n_head, 3, n_embed // n_head)) + data = np.concatenate( + ( + qkv_bias[:, 0, :].reshape((n_embed,)), + qkv_bias[:, 1, :].reshape((n_embed,)), + qkv_bias[:, 2, :].reshape((n_embed,)), + ), + axis=0, + ) + print("re-format attention.linear_qkv.bias") + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if not has_lm_head and name == "word_embeddings.weight": + self.gguf_writer.add_tensor("output.weight", data) + print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 55ec2cb5c848a..73758701c4ee8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -148,6 +148,7 @@ class MODEL_ARCH(IntEnum): OLMO = auto() ARCTIC = auto() DEEPSEEK2 = auto() + CHATGLM = auto() class MODEL_TENSOR(IntEnum): @@ -236,6 +237,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -805,6 +807,18 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.CHATGLM : [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -842,6 +856,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.CHATGLM: [ + MODEL_TENSOR.ROPE_FREQS, + ], } # diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 83e3c4c3381a0..8a3a823d10275 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -24,6 +24,7 @@ class TensorNameMap: "backbone.embedding", # mamba "backbone.embeddings", # mamba-hf "transformer.in_out_embed", # Grok + "embedding.word_embeddings", # chatglm ), # Token type embeddings @@ -52,6 +53,7 @@ class TensorNameMap: "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 + "output_layer", # chatglm ), # Output norm @@ -68,11 +70,13 @@ class TensorNameMap: "model.norm_f", # mamba-qbert "backbone.norm_f", # mamba "transformer.rms_norm", # Grok + "encoder.final_layernorm", # chatglm ), # Rope frequencies MODEL_TENSOR.ROPE_FREQS: ( "rope.freqs", # llama-pth + "rotary_pos_emb.inv_freq", # chatglm ), } @@ -97,6 +101,7 @@ class TensorNameMap: "backbone.layers.{bid}.norm", # mamba "transformer.decoder_layer.{bid}.rms_norm", # Grok "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx + "encoder.layers.{bid}.input_layernorm", # chatglm ), # Attention norm 2 @@ -117,7 +122,8 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert - "model.layers.{bid}.self_attn.qkv_proj" # phi3 + "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "encoder.layers.{bid}.self_attention.query_key_value", # chatglm ), # Attention query @@ -128,7 +134,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok ), # Attention key @@ -140,7 +146,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok ), # Attention value @@ -175,6 +181,7 @@ class TensorNameMap: "encoder.layers.{bid}.attn.out_proj", # nomic-bert "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx + "encoder.layers.{bid}.self_attention.dense", # chatglm ), # Attention output norm @@ -206,6 +213,7 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "encoder.layers.{bid}.post_attention_layernorm", # chatglm ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -245,6 +253,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic + "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -311,6 +320,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_proj", # starcoder2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w2", # arctic + "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 36e63ee3b7cd6..62129126bdddc 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.9.0" +version = "0.9.1" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/llama.cpp b/llama.cpp index dac81acc06a92..154168ef3376b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -223,6 +223,7 @@ enum llm_arch { LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, LLM_ARCH_UNKNOWN, }; @@ -261,6 +262,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1108,6 +1110,21 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -4486,6 +4503,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6319,6 +6344,36 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_CHATGLM: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6553,6 +6608,7 @@ enum llm_ffn_op_type { LLM_FFN_GELU, LLM_FFN_RELU, LLM_FFN_RELU_SQR, + LLM_FFN_SWIGLU, }; enum llm_ffn_gate_type { @@ -6743,6 +6799,19 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_sqr(ctx, cur); cb(cur, "ffn_sqr(relu)", il); } break; + case LLM_FFN_SWIGLU: + { + // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + int64_t split_point = cur->ne[0] / 2; + struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0)); + struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_silu(ctx, x0); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx, x0, x1); + cb(cur, "ffn_mul", il); + } break; } if (type_gate == LLM_FFN_PAR) { @@ -11334,6 +11403,119 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_chatglm() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_rope", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur_rope", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -11556,6 +11738,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_deepseek2(); } break; + case LLM_ARCH_CHATGLM: + { + result = llm.build_chatglm(); + } break; default: GGML_ASSERT(false); } @@ -16550,6 +16736,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OLMO: case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_CHATGLM: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -18541,6 +18728,19 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } + } else if (tmpl == "chatglm3" || + (tmpl.find("add_generation_prompt") != std::string::npos && + tmpl.find("for message in messages") != std::string::npos && + tmpl.find("loop.first") != std::string::npos)) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else { // template not supported return -1; diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cef9a650bdfdf..87f39f1039441 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -56,7 +56,9 @@ int main(void) { //Phi-3-medium "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", //Phi-3-vision - "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}" + "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", + // ChatGLM3 + "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B @@ -93,6 +95,8 @@ int main(void) { "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", //Phi-3-vision "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + // ChatGLM3 + "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", }; std::vector formatted_chat(1024); int32_t res; From 5a914ffce07afc4e7d00915e14c627d28e60cc08 Mon Sep 17 00:00:00 2001 From: XingXing Qiao Date: Wed, 15 May 2024 11:00:04 +0800 Subject: [PATCH 02/28] remove .rotary_pos_emb.inv_freq and unuse code for chatglm3 model Signed-off-by: XingXing Qiao --- convert-hf-to-gguf.py | 83 +++++-------------------------------------- 1 file changed, 8 insertions(+), 75 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index ad4bf2274581b..01422eefbf71b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2790,85 +2790,18 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_count(64) self.gguf_writer.add_add_bos_token(False) - def write_tensors(self): - block_count = self.hparams["num_layers"] - tensors = dict(self.get_tensors()) - tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) - has_lm_head = True - n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) - n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) - - for name, data_torch in tensors.items(): - if name.endswith(".rotary_pos_emb.inv_freq"): - continue - - if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys(): - has_lm_head = False + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.endswith(".rotary_pos_emb.inv_freq"): + return [] - name = re.sub(r'transformer\.', '', name) + del bid # unused - old_dtype = data_torch.dtype + name = re.sub(r'transformer\.', '', name) - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) + if name == "word_embeddings.weight": + assert self.tensor_names is not None - data = data_torch.squeeze().numpy() - - if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name): - # Map bloom-style qkv_linear to gpt-style qkv_linear - # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa - # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa - qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed)) - data = np.concatenate( - ( - qkv_weights[:, 0, :, :].reshape((-1, n_embed)), - qkv_weights[:, 1, :, :].reshape((-1, n_embed)), - qkv_weights[:, 2, :, :].reshape((-1, n_embed)), - ), - axis=0, - ) - print("re-format attention.linear_qkv.weight") - elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name): - qkv_bias = data.reshape((n_head, 3, n_embed // n_head)) - data = np.concatenate( - ( - qkv_bias[:, 0, :].reshape((n_embed,)), - qkv_bias[:, 1, :].reshape((n_embed,)), - qkv_bias[:, 2, :].reshape((n_embed,)), - ), - axis=0, - ) - print("re-format attention.linear_qkv.bias") - - # map tensor names - new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor {name!r}") - sys.exit() - - n_dims = len(data.shape) - data_dtype = data.dtype - - # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) - - # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: - data = data.astype(np.float32) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16) - - print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}") - - self.gguf_writer.add_tensor(new_name, data) - - if not has_lm_head and name == "word_embeddings.weight": - self.gguf_writer.add_tensor("output.weight", data) - print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") + return [(self.map_tensor_name(name), data_torch)] ###### CONVERSION LOGIC ###### From f626b7175c4f2938a59d0924ce64abe3869ecc6b Mon Sep 17 00:00:00 2001 From: XingXing Qiao Date: Fri, 24 May 2024 14:13:36 +0800 Subject: [PATCH 03/28] fix lint error Signed-off-by: XingXing Qiao --- convert-hf-to-gguf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 01422eefbf71b..59e25cb5802d5 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2726,8 +2726,6 @@ def set_vocab(self): vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()} - for token_id in range(vocab_size): piece = tokenizer._convert_id_to_token(token_id) if token_id == 0: From f3bc337f432a5f8d7391bd7af7bacfa55778d210 Mon Sep 17 00:00:00 2001 From: XingXing Qiao Date: Thu, 16 May 2024 11:42:53 +0800 Subject: [PATCH 04/28] optimize convert-hf-to-gguf.py for chatglm model Signed-off-by: XingXing Qiao --- convert-hf-to-gguf.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 59e25cb5802d5..7591da6efe249 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2737,13 +2737,15 @@ def set_vocab(self): text = piece.encode("utf-8") score = 0.0 - if len(piece) != 0 and token_id < 64789: + # Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py), + # it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size() + if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): score = tokenizer.tokenizer.sp_model.get_score(token_id) if len(piece) == 0: text = f"[PAD{token_id}]".encode("utf-8") - if token_id >= 64789: + if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): toktype = SentencePieceTokenTypes.UNKNOWN tokens.append(text) scores.append(score) @@ -2773,7 +2775,7 @@ def set_vocab(self): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - self.gguf_writer.add_name("ChatGLM-6b-chat") + self.gguf_writer.add_name(self.dir_model.name) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_head_kv = self.hparams.get("multi_query_group_num", n_head) @@ -2789,16 +2791,12 @@ def set_gguf_parameters(self): self.gguf_writer.add_add_bos_token(False) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - if name.endswith(".rotary_pos_emb.inv_freq"): - return [] - del bid # unused - name = re.sub(r'transformer\.', '', name) - - if name == "word_embeddings.weight": - assert self.tensor_names is not None + if name.endswith(".rotary_pos_emb.inv_freq"): + return [] + name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] From 1fc5bf5bcb9b9fbfcec8a4f1768349ca9f28a3fa Mon Sep 17 00:00:00 2001 From: XingXing Qiao Date: Mon, 17 Jun 2024 10:08:52 +0800 Subject: [PATCH 05/28] support glm-4-9b-chat Signed-off-by: XingXing Qiao --- convert-hf-to-gguf.py | 96 ++++++++++++++++++++++++++++++++++-- examples/server/server.cpp | 2 +- llama.cpp | 20 ++++++-- llama.h | 1 + tests/test-chat-template.cpp | 4 ++ 5 files changed, 116 insertions(+), 7 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 7591da6efe249..b36b5193c7881 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -476,6 +476,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d": # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct res = "smaug-bpe" + if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" if res is None: logger.warning("\n") @@ -2714,7 +2717,7 @@ def write_tensors(self): class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM - def set_vocab(self): + def set_vocab_chatglm3(self): dir_model = self.dir_model hparams = self.hparams tokens: list[bytearray] = [] @@ -2725,7 +2728,8 @@ def set_vocab(self): tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size - + print(vocab_size) + print(max(tokenizer.get_vocab().values())) for token_id in range(vocab_size): piece = tokenizer._convert_id_to_token(token_id) if token_id == 0: @@ -2774,6 +2778,91 @@ def set_vocab(self): special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab.add_to_gguf(self.gguf_writer) + @staticmethod + def token_bytes_to_string(b): + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() + return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + + @staticmethod + def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + return parts + + def set_vocab(self): + if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""): + self.set_vocab_chatglm3() + return + + dir_model = self.dir_model + hparams = self.hparams + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams["padded_vocab_size"] + assert max(tokenizer.get_vocab().values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[ChatGLMModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank) + assert len(merged) >= 2 and len(merged) <= 7 + merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged))) + + # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined + added_vocab = tokenizer.get_added_vocab() + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()} + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) + special_vocab.chat_template = "ChatGLM4" + special_vocab.merges = merges + # only add special tokens when they were not already loaded from config.json + if len(special_vocab.special_token_ids) == 0: + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + # this one is usually not in config.json anyway + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + def set_gguf_parameters(self): self.gguf_writer.add_name(self.dir_model.name) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) @@ -2934,7 +3023,8 @@ def main() -> None: with torch.inference_mode(): model_class = Model.from_model_architecture(hparams["architectures"][0]) model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy) - + print(model_class) + print(model_instance) logger.info("Set model parameters") model_instance.set_gguf_parameters() diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e9904263d53c7..0788063575edf 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3056,7 +3056,7 @@ int main(int argc, char ** argv) { chat.push_back({{"role", "user"}, {"content", "Hello"}}); chat.push_back({{"role", "assistant"}, {"content", "Hi there"}}); chat.push_back({{"role", "user"}, {"content", "How are you?"}}); - + printf("sparams.chat_template: #%s#\n", sparams.chat_template.c_str()); const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat); LOG_INFO("chat template", { diff --git a/llama.cpp b/llama.cpp index 154168ef3376b..a0255bac894b3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4508,6 +4508,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 28: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_8B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4636,9 +4637,9 @@ static void llm_load_vocab( if (merges_keyidx == -1) { throw std::runtime_error("cannot find tokenizer merges in model file\n"); } - + printf("merges_keyidx: %d\n", merges_keyidx); const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - + printf("n_merges: %d\n", n_merges); for (int i = 0; i < n_merges; i++) { const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); @@ -4728,6 +4729,9 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "smaug-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "chatglm-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -11449,7 +11453,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - + //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, @@ -13032,6 +13036,7 @@ struct llm_tokenizer_bpe { break; case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_SMAUG: + case LLAMA_VOCAB_PRE_TYPE_CHATGLM4: word_collection = unicode_regex_split(text, { // same as llama3 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", @@ -18741,6 +18746,15 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } + } else if (tmpl == "ChatGLM4") { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else { // template not supported return -1; diff --git a/llama.h b/llama.h index 3e4474bb94e9a..a670e19112755 100644 --- a/llama.h +++ b/llama.h @@ -86,6 +86,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_OLMO = 12, LLAMA_VOCAB_PRE_TYPE_DBRX = 13, LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 15, }; // note: these values should be synchronized with ggml_rope diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 87f39f1039441..0fe4d29674269 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -59,6 +59,8 @@ int main(void) { "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", // ChatGLM3 "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + // ChatGLM4 + "ChatGLM4", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B @@ -97,6 +99,8 @@ int main(void) { "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", // ChatGLM3 "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + // ChatGLM4 + "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", }; std::vector formatted_chat(1024); int32_t res; From 8c5f1b2b6c4d8d5afde26769b9721f4cb6ec5665 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Thu, 20 Jun 2024 08:10:00 +0000 Subject: [PATCH 06/28] fix eos tokens to glm4 --- convert-hf-to-gguf.py | 17 ++++++++++++----- llama.cpp | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b36b5193c7881..dc70e26d56f63 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2728,6 +2728,8 @@ def set_vocab_chatglm3(self): tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size + role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens print(vocab_size) print(max(tokenizer.get_vocab().values())) for token_id in range(vocab_size): @@ -2750,7 +2752,12 @@ def set_vocab_chatglm3(self): text = f"[PAD{token_id}]".encode("utf-8") if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): - toktype = SentencePieceTokenTypes.UNKNOWN + if piece in special_tokens: + # show special tokens in prompt + toktype = SentencePieceTokenTypes.USER_DEFINED + else: + print(f"unknow token: {piece}") + toktype = SentencePieceTokenTypes.UNKNOWN tokens.append(text) scores.append(score) toktypes.append(toktype) @@ -2856,9 +2863,9 @@ def set_vocab(self): special_vocab.chat_template = "ChatGLM4" special_vocab.merges = merges # only add special tokens when they were not already loaded from config.json - if len(special_vocab.special_token_ids) == 0: - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + # if len(special_vocab.special_token_ids) == 0: + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # this one is usually not in config.json anyway special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab.add_to_gguf(self.gguf_writer) @@ -2955,7 +2962,7 @@ def parse_args() -> argparse.Namespace: help="model is executed on big endian machine", ) parser.add_argument( - "model", type=Path, + "--model", type=Path, help="directory containing model file", ) parser.add_argument( diff --git a/llama.cpp b/llama.cpp index a0255bac894b3..9e23f66435e64 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1802,9 +1802,11 @@ enum e_model { MODEL_2_8B, MODEL_3B, MODEL_4B, + MODEL_6B, MODEL_6_9B, MODEL_7B, MODEL_8B, + MODEL_9B, MODEL_12B, MODEL_13B, MODEL_14B, @@ -3918,9 +3920,11 @@ static const char * llama_model_type_name(e_model type) { case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; case MODEL_4B: return "4B"; + case MODEL_6B: return "6B"; case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_9B: return "9B"; case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; @@ -4507,8 +4511,8 @@ static void llm_load_hparams( { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_8B; break; + case 28: model.type = e_model::MODEL_6B; break; + case 40: model.type = e_model::MODEL_9B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -18362,6 +18366,19 @@ llama_token_type llama_token_get_type(const struct llama_model * model, llama_to } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { + auto arch_name = llama_model_arch_name(model->arch); + auto vocab_type = model->vocab.type; + if (strcmp(arch_name, "chatglm") == 0) { + if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 + return token != -1 && ( + token == llama_token_eos(model) || + token == llama_token_eot(model) || + token == 151329 || + token == 151336 || + token == 151338 + ); + } + } return token != -1 && ( token == llama_token_eos(model) || token == llama_token_eot(model) @@ -18424,8 +18441,18 @@ int32_t llama_tokenize( int32_t n_tokens_max, bool add_special, bool parse_special) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); - + auto arch_name = llama_model_arch_name(model->arch); + auto prompt = std::move(std::string(text, text_len)); + auto vocab_type = model->vocab.type; + if (strcmp(arch_name, "chatglm") == 0) { + // chatglm3 + if (LLAMA_VOCAB_TYPE_SPM == vocab_type) { + prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"; + } else if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 + prompt = "[gMASK]<|user|>\n" + prompt + "<|assistant|>"; + } + } + auto res = llama_tokenize_internal(model->vocab, prompt, add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); From 95fd910d32a288d0aaad4146a580d0861c764a66 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Thu, 20 Jun 2024 08:20:12 +0000 Subject: [PATCH 07/28] remove unused log --- convert-hf-to-gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index dc70e26d56f63..4595ee60c6cf2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2756,7 +2756,6 @@ def set_vocab_chatglm3(self): # show special tokens in prompt toktype = SentencePieceTokenTypes.USER_DEFINED else: - print(f"unknow token: {piece}") toktype = SentencePieceTokenTypes.UNKNOWN tokens.append(text) scores.append(score) From 4b65b648ce4a06b566cffdfb0ee09b55b79a15bb Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Fri, 21 Jun 2024 07:47:51 +0000 Subject: [PATCH 08/28] add preprocess to chatglm3 and chatglm4 --- convert-hf-to-gguf.py | 3 +++ llama.cpp | 41 ++++++++++++++++++++++++++++------------- llama.h | 3 ++- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c5cb8bbeca6da..3305b8cebb8f8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2792,6 +2792,9 @@ def set_vocab_chatglm3(self): toktypes.append(toktype) self.gguf_writer.add_tokenizer_model("llama") + # glm3 needs prefix and suffix formatted as: + # prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>" + self.gguf_writer.add_tokenizer_pre("chatglm-spm") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) diff --git a/llama.cpp b/llama.cpp index a2df298a8b19c..a2ac68379b856 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4789,6 +4789,10 @@ static void llm_load_vocab( return; } else if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; + // chatglm3 needs to preprocess prefix and suffix + if (tokenizer_pre == "chatglm-spm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM3; + } // default special tokens vocab.special_bos_id = 1; @@ -13923,6 +13927,14 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(vocab.special_bos_id); is_prev_special = true; } + // add prefix to chatglm3 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { + output.push_back(64790); + output.push_back(64792); + output.push_back(64795); + output.push_back(30910); + output.push_back(13); + } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -13957,6 +13969,10 @@ static std::vector llama_tokenize_internal(const llama_vocab & GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } + // add suffix to chatglm3 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { + output.push_back(64796); + } } break; case LLAMA_VOCAB_TYPE_BPE: { @@ -13965,7 +13981,13 @@ static std::vector llama_tokenize_internal(const llama_vocab & if (add_special) { tokenizer.append_bos(output); } - + // add prefix to chatglm4 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { + output.push_back(151331); + output.push_back(151333); + output.push_back(151336); + output.push_back(198); + } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -13983,6 +14005,10 @@ static std::vector llama_tokenize_internal(const llama_vocab & tokenizer.append_eos(output); tokenizer.check_double_bos_eos(output); } + // add suffix to chatglm4 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { + output.push_back(151337); + } } break; case LLAMA_VOCAB_TYPE_WPM: { @@ -18599,18 +18625,7 @@ int32_t llama_tokenize( int32_t n_tokens_max, bool add_special, bool parse_special) { - auto arch_name = llama_model_arch_name(model->arch); - auto prompt = std::move(std::string(text, text_len)); - auto vocab_type = model->vocab.type; - if (strcmp(arch_name, "chatglm") == 0) { - // chatglm3 - if (LLAMA_VOCAB_TYPE_SPM == vocab_type) { - prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"; - } else if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 - prompt = "[gMASK]<|user|>\n" + prompt + "<|assistant|>"; - } - } - auto res = llama_tokenize_internal(model->vocab, prompt, add_special, parse_special); + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); diff --git a/llama.h b/llama.h index b1ff05bd719be..a85b568b9aae9 100644 --- a/llama.h +++ b/llama.h @@ -87,7 +87,8 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_DBRX = 13, LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, }; // note: these values should be synchronized with ggml_rope From 3a4d5790bfdc205c5b658204239f168fc21cc1a8 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Mon, 24 Jun 2024 12:27:02 +0000 Subject: [PATCH 09/28] add eos_id_list to llama.cpp --- common/common.cpp | 26 ++++++++-- common/train.cpp | 6 ++- convert-hf-to-gguf.py | 9 ++-- convert-llama-ggml-to-gguf.py | 2 +- examples/gritlm/gritlm.cpp | 3 +- examples/retrieval/retrieval.cpp | 9 +++- examples/server/server.cpp | 22 +++++++-- examples/speculative/speculative.cpp | 11 ++++- gguf-py/gguf/constants.py | 6 ++- gguf-py/gguf/gguf_writer.py | 6 +-- llama.cpp | 72 +++++++++++++++++----------- llama.h | 3 +- tests/test-chat-template.cpp | 2 +- 13 files changed, 122 insertions(+), 55 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 73ff0e85b7b4e..657d2ffa84548 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2417,14 +2417,21 @@ std::tuple llama_init_from_gpt_par } } + const int n_eos = llama_n_eos(llama_get_model(lctx)); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(llama_get_model(lctx), eos_ptr); if (params.ignore_eos) { - params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + for (int32_t i = 0; i < n_eos; ++i) { + params.sparams.logit_bias[eos_ptr[i]] = -INFINITY; + } } if (params.warmup) { LOG("warming up the model with an empty run\n"); - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + std::vector tmp = { llama_token_bos(model) }; + tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end()); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); @@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; + const int n_eos = llama_n_eos(llama_get_model(lctx)); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(llama_get_model(lctx), eos_ptr); + bool ignore_eos = false; + for (auto eos: eos_tokens) { + const auto logit_bias_eos = sparams.logit_bias.find(eos); + if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) { + ignore_eos = true; + } + } fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); @@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "logit_bias:\n"); for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { + if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) { continue; } fprintf(stream, " %d: %f", lb.first, lb.second); diff --git a/common/train.cpp b/common/train.cpp index fef1e57c94655..96ea4165ee4bb 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -240,7 +240,11 @@ int64_t get_example_targets_batch( ggml_set_f32(target_probs, 0.0f); llama_token bos = llama_token_bos(llama_get_model(lctx)); - llama_token eos = llama_token_eos(llama_get_model(lctx)); + const int n_eos = llama_n_eos(llama_get_model(lctx)); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(llama_get_model(lctx), eos_ptr); + llama_token eos = eos_ptr[0]; // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); for (int k=0; k= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { - inp.push_back(llama_token_eos(model)); + const int n_eos = llama_n_eos(model); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(model, eos_ptr); + + if (!eos_tokens.empty() && (inp.empty() || std::count(eos_tokens.begin(), eos_tokens.end(), inp.back()))) { + inp.insert(inp.end(), eos_tokens.begin(), eos_tokens.end()); } chunk.tokens = inp; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f9a86961f9c8e..5be18bc5456ca 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1021,7 +1021,13 @@ struct server_context { slot.sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false)) { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + const int n_eos = llama_n_eos(model); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(model, eos_ptr); + for (int32_t i = 0; i < n_eos; ++i) { + slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY; + } } const auto & logit_bias = data.find("logit_bias"); @@ -1308,9 +1314,17 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - + const int n_eos = llama_n_eos(model); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(model, eos_ptr); + bool ignore_eos = false; + for (auto eos: eos_tokens) { + const auto logit_bias_eos = slot.sparams.logit_bias.find(eos); + if (logit_bias_eos != slot.sparams.logit_bias.end() && eos < 0.0f && std::isinf(logit_bias_eos->second)) { + ignore_eos = true; + } + } std::vector samplers_sequence; samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); for (const auto & sampler_type : slot.sparams.samplers_sequence) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0939a1a6a7a38..43b4278ae23e2 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -88,12 +88,21 @@ int main(int argc, char ** argv) { fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); return 1; } + const int n_eos_tgt = llama_n_eos(model_tgt); + std::vector eos_tokens_tgt(n_eos_tgt, 0); + int32_t* eos_ptr_tgt = eos_tokens_tgt.data(); + llama_token_eos(model_tgt, eos_ptr_tgt); + + const int n_eos_dft = llama_n_eos(model_dft); + std::vector eos_tokens_dft(n_eos_dft, 0); + int32_t* eos_ptr_dft = eos_tokens_dft.data(); + llama_token_eos(model_dft, eos_ptr_dft); if ( llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || llama_token_bos(model_tgt) != llama_token_bos(model_dft) || - llama_token_eos(model_tgt) != llama_token_eos(model_dft) + eos_tokens_tgt != eos_tokens_dft ) { fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__); return 1; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 65aa3298df911..cb512c4f28c43 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -88,7 +88,7 @@ class Tokenizer: SCORES = "tokenizer.ggml.scores" MERGES = "tokenizer.ggml.merges" BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" # recommand eos_id_list UNK_ID = "tokenizer.ggml.unknown_token_id" SEP_ID = "tokenizer.ggml.seperator_token_id" PAD_ID = "tokenizer.ggml.padding_token_id" @@ -107,6 +107,8 @@ class Tokenizer: SUFFIX_ID = "tokenizer.ggml.suffix_token_id" MIDDLE_ID = "tokenizer.ggml.middle_token_id" EOT_ID = "tokenizer.ggml.eot_token_id" + EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list" + # @@ -1091,7 +1093,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID -KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID +KEY_TOKENIZER_EOS_ID_LIST= Keys.Tokenizer.EOS_ID_LIST KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a697f657b9ac8..e4f6868d9779d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -510,9 +510,9 @@ def add_token_scores(self, scores: Sequence[float]) -> None: def add_bos_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.BOS_ID, id) - - def add_eos_token_id(self, id: int) -> None: - self.add_uint32(Keys.Tokenizer.EOS_ID, id) + + def add_eos_token_id_list(self, id: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + self.add_array(Keys.Tokenizer.EOS_ID_LIST, id) def add_unk_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.UNK_ID, id) diff --git a/llama.cpp b/llama.cpp index a2ac68379b856..ea5c76cac4f5b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -337,7 +337,7 @@ enum llm_kv { LLM_KV_TOKENIZER_SCORES, LLM_KV_TOKENIZER_MERGES, LLM_KV_TOKENIZER_BOS_ID, - LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_EOS_ID, //compatibility with previous versions LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_PAD_ID, @@ -352,6 +352,7 @@ enum llm_kv { LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOS_ID_LIST }; static const std::map LLM_KV_NAMES = { @@ -438,6 +439,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID_LIST, "tokenizer.ggml.eos_token_id_list" }, }; struct LLM_KV { @@ -2328,6 +2330,7 @@ struct llama_vocab { id special_pad_id = -1; id special_cls_id = -1; id special_mask_id = -1; + std::set special_eos_id_list; id linefeed_id = 13; id special_prefix_id = -1; @@ -5084,6 +5087,24 @@ static void llm_load_vocab( } } + const int eos_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_EOS_ID_LIST).c_str()); + if (eos_idx == -1) { + vocab.special_eos_id_list.clear(); + vocab.special_eos_id_list.insert(vocab.special_eos_id); + } else { + const uint32_t n_eos = gguf_get_arr_n(ctx, eos_idx); + const int* eos_tokens = (const int*)gguf_get_arr_data(ctx, eos_idx); + if (n_eos > 0) { + vocab.special_eos_id_list.clear(); + } else { + vocab.special_eos_id_list.clear(); + vocab.special_eos_id_list.insert(vocab.special_eos_id); + } + for (uint32_t i = 0; i < n_eos; ++i) { + vocab.special_eos_id_list.insert(eos_tokens[i]); + } + } + // Handle add_bos_token and add_eos_token { bool temp = true; @@ -5273,7 +5294,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { // special tokens if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (!vocab.special_eos_id_list.empty()) { + for (auto it = vocab.special_eos_id_list.begin(); it != vocab.special_eos_id_list.end(); ++it) { + LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, *it, vocab.id_to_token[*it].text.c_str() ); + } + } if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } @@ -13482,8 +13507,8 @@ struct llm_tokenizer_bpe { bool append_eos(std::vector & output) const { if (vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); + GGML_ASSERT(!vocab.special_eos_id_list.empty()); + output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end()); return true; } return false; @@ -13496,7 +13521,7 @@ struct llm_tokenizer_bpe { "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) { + if (vocab.tokenizer_add_eos && output.size() >= 2 && vocab.special_eos_id_list.find(*(output.end()-2)) != vocab.special_eos_id_list.end()) { LLAMA_LOG_WARN( "%s: Added a EOS token to the prompt as specified by the model but the prompt " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " @@ -13966,8 +13991,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & } if (add_special && vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); + GGML_ASSERT(!vocab.special_eos_id_list.empty()); + output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end()); } // add suffix to chatglm3 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { @@ -16966,6 +16991,10 @@ int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } +int32_t llama_n_eos(const struct llama_model * model) { + return model->vocab.special_eos_id_list.size(); +} + int32_t llama_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } @@ -18550,21 +18579,8 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { - auto arch_name = llama_model_arch_name(model->arch); - auto vocab_type = model->vocab.type; - if (strcmp(arch_name, "chatglm") == 0) { - if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 - return token != -1 && ( - token == llama_token_eos(model) || - token == llama_token_eot(model) || - token == 151329 || - token == 151336 || - token == 151338 - ); - } - } return token != -1 && ( - token == llama_token_eos(model) || + model->vocab.special_eos_id_list.count(token) || token == llama_token_eot(model) ); } @@ -18577,8 +18593,11 @@ llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } -llama_token llama_token_eos(const struct llama_model * model) { - return model->vocab.special_eos_id; +void llama_token_eos(const struct llama_model * model, llama_token* token_list) { + int ind = 0; + for (auto it = model->vocab.special_eos_id_list.begin(); it != model->vocab.special_eos_id_list.end(); ++it) { + token_list[ind++] = *it; + } } llama_token llama_token_cls(const struct llama_model * model) { @@ -18952,10 +18971,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || - (tmpl.find("add_generation_prompt") != std::string::npos && - tmpl.find("for message in messages") != std::string::npos && - tmpl.find("loop.first") != std::string::npos)) { + } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -18965,7 +18981,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "ChatGLM4") { + } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]") != std::string::npos) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); diff --git a/llama.h b/llama.h index a85b568b9aae9..84172f44353d0 100644 --- a/llama.h +++ b/llama.h @@ -448,6 +448,7 @@ extern "C" { LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); + LLAMA_API int32_t llama_n_eos (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); @@ -851,7 +852,7 @@ extern "C" { // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence + LLAMA_API void llama_token_eos(const struct llama_model * model, llama_token* token_list); // end-of-sentence LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 0fe4d29674269..399dc57b9374f 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -60,7 +60,7 @@ int main(void) { // ChatGLM3 "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", // ChatGLM4 - "ChatGLM4", + "chatglm4", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B From 3b67ff808a93b95f349890a13f2cfc62dc1988fb Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Tue, 25 Jun 2024 02:22:55 +0000 Subject: [PATCH 10/28] fix code style --- convert-hf-to-gguf.py | 2 +- gguf-py/gguf/constants.py | 2 +- gguf-py/gguf/gguf_writer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b56bb8afc012b..97b02f6193658 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2895,6 +2895,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + @Model.register("ChatGLMModel") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -3081,7 +3082,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] - ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 33e6463c5d053..d08d154ad2e3a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -119,7 +119,6 @@ class Tokenizer: EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list" - # # recommended mapping of model tensor names for storage in gguf # @@ -164,6 +163,7 @@ class MODEL_ARCH(IntEnum): BITNET = auto() T5 = auto() + class MODEL_TENSOR(IntEnum): TOKEN_EMBD = auto() TOKEN_EMBD_NORM = auto() diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e5ad3c76e8f77..b7bbaeee2d12a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -608,7 +608,7 @@ def add_token_scores(self, scores: Sequence[float]) -> None: def add_bos_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.BOS_ID, id) - + def add_eos_token_id_list(self, id: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: self.add_array(Keys.Tokenizer.EOS_ID_LIST, id) From 5f8f465d0d4e5b89406d9600230ae81878a5964e Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Tue, 25 Jun 2024 02:29:09 +0000 Subject: [PATCH 11/28] fix code style --- gguf-py/gguf/constants.py | 44 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d08d154ad2e3a..db1ca76f56d97 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -138,28 +138,28 @@ class MODEL_ARCH(IntEnum): BERT = auto() NOMIC_BERT = auto() JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - PHI2 = auto() - PHI3 = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - GEMMA = auto() - STARCODER2 = auto() - MAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - DBRX = auto() - OLMO = auto() - ARCTIC = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + PHI2 = auto() + PHI3 = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + GEMMA = auto() + STARCODER2 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + DBRX = auto() + OLMO = auto() + ARCTIC = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() BITNET = auto() T5 = auto() From f8d4fc987ebb7c83ddbaae62bc090709e2d43232 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Tue, 25 Jun 2024 03:09:49 +0000 Subject: [PATCH 12/28] fix conflicts --- llama.cpp | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4de502786ae9a..3a9806a9b41d7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4768,14 +4768,14 @@ static void llm_load_hparams( } } break; case LLM_ARCH_CHATGLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_6B; break; - case 40: model.type = e_model::MODEL_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_6B; break; + case 40: model.type = e_model::MODEL_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -11966,6 +11966,11 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + nullptr, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_sub_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_sub_norm", il); @@ -12153,7 +12158,6 @@ struct llm_build_context { return gf; } - }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { From a67bc8f5a8438f5540eecf175368b5db5be6f3f6 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Tue, 25 Jun 2024 06:00:43 +0000 Subject: [PATCH 13/28] fix conflicts --- llama.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index ea5c76cac4f5b..d5b7b4165eb0f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4792,10 +4792,6 @@ static void llm_load_vocab( return; } else if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; - // chatglm3 needs to preprocess prefix and suffix - if (tokenizer_pre == "chatglm-spm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM3; - } // default special tokens vocab.special_bos_id = 1; @@ -4944,6 +4940,13 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = true; vocab.tokenizer_add_eos = false; + // chatglm3 needs to preprocess prefix and suffix + if (tokenizer_pre == "chatglm-spm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM3; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = false; + vocab.tokenizer_add_space_prefix = false; + } } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = true; @@ -5040,7 +5043,7 @@ static void llm_load_vocab( vocab.special_eot_id = 107; } } - + try { vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); } catch (const std::exception & e) { @@ -13946,7 +13949,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & // tokenizer.encode('', add_special_tokens=False) returns [] bool is_prev_special = false; - if (add_special && vocab.tokenizer_add_bos) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); From 89e8aaf960c5a45a7956dd99cc547fddbd95c738 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Tue, 25 Jun 2024 09:23:57 +0000 Subject: [PATCH 14/28] Revert "add eos_id_list to llama.cpp" This reverts commit 3a4d5790bfdc205c5b658204239f168fc21cc1a8. --- common/common.cpp | 26 ++-------- common/train.cpp | 6 +-- convert-hf-to-gguf.py | 9 ++-- convert-llama-ggml-to-gguf.py | 2 +- examples/gritlm/gritlm.cpp | 3 +- examples/retrieval/retrieval.cpp | 9 +--- examples/server/server.cpp | 22 ++------- examples/speculative/speculative.cpp | 11 +---- gguf-py/gguf/constants.py | 6 +-- gguf-py/gguf/gguf_writer.py | 4 +- llama.cpp | 72 +++++++++++----------------- llama.h | 3 +- tests/test-chat-template.cpp | 2 +- 13 files changed, 54 insertions(+), 121 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 72098de04f00a..1dc53265134a7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2053,21 +2053,14 @@ std::tuple llama_init_from_gpt_par } } - const int n_eos = llama_n_eos(llama_get_model(lctx)); - std::vector eos_tokens(n_eos, 0); - int32_t* eos_ptr = eos_tokens.data(); - llama_token_eos(llama_get_model(lctx), eos_ptr); if (params.ignore_eos) { - for (int32_t i = 0; i < n_eos; ++i) { - params.sparams.logit_bias[eos_ptr[i]] = -INFINITY; - } + params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; } if (params.warmup) { LOG("warming up the model with an empty run\n"); - std::vector tmp = { llama_token_bos(model) }; - tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end()); + std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); @@ -3035,17 +3028,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - const int n_eos = llama_n_eos(llama_get_model(lctx)); - std::vector eos_tokens(n_eos, 0); - int32_t* eos_ptr = eos_tokens.data(); - llama_token_eos(llama_get_model(lctx), eos_ptr); - bool ignore_eos = false; - for (auto eos: eos_tokens) { - const auto logit_bias_eos = sparams.logit_bias.find(eos); - if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) { - ignore_eos = true; - } - } + const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); + const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); @@ -3058,7 +3042,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "logit_bias:\n"); for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) { + if (ignore_eos && lb.first == logit_bias_eos->first) { continue; } fprintf(stream, " %d: %f", lb.first, lb.second); diff --git a/common/train.cpp b/common/train.cpp index 96ea4165ee4bb..fef1e57c94655 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -240,11 +240,7 @@ int64_t get_example_targets_batch( ggml_set_f32(target_probs, 0.0f); llama_token bos = llama_token_bos(llama_get_model(lctx)); - const int n_eos = llama_n_eos(llama_get_model(lctx)); - std::vector eos_tokens(n_eos, 0); - int32_t* eos_ptr = eos_tokens.data(); - llama_token_eos(llama_get_model(lctx), eos_ptr); - llama_token eos = eos_ptr[0]; + llama_token eos = llama_token_eos(llama_get_model(lctx)); // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); for (int k=0; k eos_tokens(n_eos, 0); - int32_t* eos_ptr = eos_tokens.data(); - llama_token_eos(model, eos_ptr); - - if (!eos_tokens.empty() && (inp.empty() || std::count(eos_tokens.begin(), eos_tokens.end(), inp.back()))) { - inp.insert(inp.end(), eos_tokens.begin(), eos_tokens.end()); + if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { + inp.push_back(llama_token_eos(model)); } chunk.tokens = inp; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5be18bc5456ca..f9a86961f9c8e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1021,13 +1021,7 @@ struct server_context { slot.sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false)) { - const int n_eos = llama_n_eos(model); - std::vector eos_tokens(n_eos, 0); - int32_t* eos_ptr = eos_tokens.data(); - llama_token_eos(model, eos_ptr); - for (int32_t i = 0; i < n_eos; ++i) { - slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY; - } + slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; } const auto & logit_bias = data.find("logit_bias"); @@ -1314,17 +1308,9 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - const int n_eos = llama_n_eos(model); - std::vector eos_tokens(n_eos, 0); - int32_t* eos_ptr = eos_tokens.data(); - llama_token_eos(model, eos_ptr); - bool ignore_eos = false; - for (auto eos: eos_tokens) { - const auto logit_bias_eos = slot.sparams.logit_bias.find(eos); - if (logit_bias_eos != slot.sparams.logit_bias.end() && eos < 0.0f && std::isinf(logit_bias_eos->second)) { - ignore_eos = true; - } - } + const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); + const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); + std::vector samplers_sequence; samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); for (const auto & sampler_type : slot.sparams.samplers_sequence) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 43b4278ae23e2..0939a1a6a7a38 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -88,21 +88,12 @@ int main(int argc, char ** argv) { fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); return 1; } - const int n_eos_tgt = llama_n_eos(model_tgt); - std::vector eos_tokens_tgt(n_eos_tgt, 0); - int32_t* eos_ptr_tgt = eos_tokens_tgt.data(); - llama_token_eos(model_tgt, eos_ptr_tgt); - - const int n_eos_dft = llama_n_eos(model_dft); - std::vector eos_tokens_dft(n_eos_dft, 0); - int32_t* eos_ptr_dft = eos_tokens_dft.data(); - llama_token_eos(model_dft, eos_ptr_dft); if ( llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || llama_token_bos(model_tgt) != llama_token_bos(model_dft) || - eos_tokens_tgt != eos_tokens_dft + llama_token_eos(model_tgt) != llama_token_eos(model_dft) ) { fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__); return 1; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index db1ca76f56d97..1e0afe9d39066 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -95,7 +95,7 @@ class Tokenizer: SCORES = "tokenizer.ggml.scores" MERGES = "tokenizer.ggml.merges" BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" # recommand eos_id_list + EOS_ID = "tokenizer.ggml.eos_token_id" UNK_ID = "tokenizer.ggml.unknown_token_id" SEP_ID = "tokenizer.ggml.seperator_token_id" PAD_ID = "tokenizer.ggml.padding_token_id" @@ -116,8 +116,6 @@ class Tokenizer: SUFFIX_ID = "tokenizer.ggml.suffix_token_id" MIDDLE_ID = "tokenizer.ggml.middle_token_id" EOT_ID = "tokenizer.ggml.eot_token_id" - EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list" - # # recommended mapping of model tensor names for storage in gguf @@ -1212,7 +1210,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID -KEY_TOKENIZER_EOS_ID_LIST= Keys.Tokenizer.EOS_ID_LIST +KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b7bbaeee2d12a..9869f6fe3445a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -609,8 +609,8 @@ def add_token_scores(self, scores: Sequence[float]) -> None: def add_bos_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.BOS_ID, id) - def add_eos_token_id_list(self, id: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: - self.add_array(Keys.Tokenizer.EOS_ID_LIST, id) + def add_eos_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOS_ID, id) def add_unk_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.UNK_ID, id) diff --git a/llama.cpp b/llama.cpp index 08c44526ffcf5..aea08fa521bc1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -339,7 +339,7 @@ enum llm_kv { LLM_KV_TOKENIZER_SCORES, LLM_KV_TOKENIZER_MERGES, LLM_KV_TOKENIZER_BOS_ID, - LLM_KV_TOKENIZER_EOS_ID, //compatibility with previous versions + LLM_KV_TOKENIZER_EOS_ID, LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_PAD_ID, @@ -354,7 +354,6 @@ enum llm_kv { LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, LLM_KV_TOKENIZER_EOT_ID, - LLM_KV_TOKENIZER_EOS_ID_LIST }; static const std::map LLM_KV_NAMES = { @@ -441,7 +440,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID_LIST, "tokenizer.ggml.eos_token_id_list" }, }; struct LLM_KV { @@ -2365,7 +2363,6 @@ struct llama_vocab { id special_pad_id = -1; id special_cls_id = -1; id special_mask_id = -1; - std::set special_eos_id_list; id linefeed_id = 13; id special_prefix_id = -1; @@ -5135,24 +5132,6 @@ static void llm_load_vocab( } } - const int eos_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_EOS_ID_LIST).c_str()); - if (eos_idx == -1) { - vocab.special_eos_id_list.clear(); - vocab.special_eos_id_list.insert(vocab.special_eos_id); - } else { - const uint32_t n_eos = gguf_get_arr_n(ctx, eos_idx); - const int* eos_tokens = (const int*)gguf_get_arr_data(ctx, eos_idx); - if (n_eos > 0) { - vocab.special_eos_id_list.clear(); - } else { - vocab.special_eos_id_list.clear(); - vocab.special_eos_id_list.insert(vocab.special_eos_id); - } - for (uint32_t i = 0; i < n_eos; ++i) { - vocab.special_eos_id_list.insert(eos_tokens[i]); - } - } - // Handle add_bos_token and add_eos_token { bool temp = true; @@ -5342,11 +5321,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { // special tokens if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (!vocab.special_eos_id_list.empty()) { - for (auto it = vocab.special_eos_id_list.begin(); it != vocab.special_eos_id_list.end(); ++it) { - LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, *it, vocab.id_to_token[*it].text.c_str() ); - } - } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } @@ -13783,8 +13758,8 @@ struct llm_tokenizer_bpe { bool append_eos(std::vector & output) const { if (vocab.tokenizer_add_eos) { - GGML_ASSERT(!vocab.special_eos_id_list.empty()); - output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end()); + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); return true; } return false; @@ -13797,7 +13772,7 @@ struct llm_tokenizer_bpe { "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (vocab.tokenizer_add_eos && output.size() >= 2 && vocab.special_eos_id_list.find(*(output.end()-2)) != vocab.special_eos_id_list.end()) { + if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) { LLAMA_LOG_WARN( "%s: Added a EOS token to the prompt as specified by the model but the prompt " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " @@ -14267,8 +14242,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & } if (add_special && vocab.tokenizer_add_eos) { - GGML_ASSERT(!vocab.special_eos_id_list.empty()); - output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end()); + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); } // add suffix to chatglm3 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { @@ -17269,10 +17244,6 @@ int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } -int32_t llama_n_eos(const struct llama_model * model) { - return model->vocab.special_eos_id_list.size(); -} - int32_t llama_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } @@ -18861,8 +18832,21 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { + auto arch_name = llama_model_arch_name(model->arch); + auto vocab_type = model->vocab.type; + if (strcmp(arch_name, "chatglm") == 0) { + if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 + return token != -1 && ( + token == llama_token_eos(model) || + token == llama_token_eot(model) || + token == 151329 || + token == 151336 || + token == 151338 + ); + } + } return token != -1 && ( - model->vocab.special_eos_id_list.count(token) || + token == llama_token_eos(model) || token == llama_token_eot(model) ); } @@ -18875,11 +18859,8 @@ llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } -void llama_token_eos(const struct llama_model * model, llama_token* token_list) { - int ind = 0; - for (auto it = model->vocab.special_eos_id_list.begin(); it != model->vocab.special_eos_id_list.end(); ++it) { - token_list[ind++] = *it; - } +llama_token llama_token_eos(const struct llama_model * model) { + return model->vocab.special_eos_id; } llama_token llama_token_cls(const struct llama_model * model) { @@ -19253,7 +19234,10 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) { + } else if (tmpl == "chatglm3" || + (tmpl.find("add_generation_prompt") != std::string::npos && + tmpl.find("for message in messages") != std::string::npos && + tmpl.find("loop.first") != std::string::npos)) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -19263,7 +19247,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]") != std::string::npos) { + } else if (tmpl == "ChatGLM4") { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); diff --git a/llama.h b/llama.h index e195d2d9e99c2..51c14cb4e0651 100644 --- a/llama.h +++ b/llama.h @@ -448,7 +448,6 @@ extern "C" { LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); - LLAMA_API int32_t llama_n_eos (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); @@ -856,7 +855,7 @@ extern "C" { // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence - LLAMA_API void llama_token_eos(const struct llama_model * model, llama_token* token_list); // end-of-sentence + LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 399dc57b9374f..0fe4d29674269 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -60,7 +60,7 @@ int main(void) { // ChatGLM3 "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", // ChatGLM4 - "chatglm4", + "ChatGLM4", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B From 9396c7bbaf93a0d0f49c6b92cbb8635e5a2e81be Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Wed, 26 Jun 2024 02:16:12 +0000 Subject: [PATCH 15/28] set <|endoftext|> as eos and <|user|> as eot --- convert-hf-to-gguf.py | 2 +- llama.cpp | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 4f6308d467d2e..363c09720c749 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3048,9 +3048,9 @@ def set_vocab(self): special_vocab.chat_template = "ChatGLM4" special_vocab.merges = merges # only add special tokens when they were not already loaded from config.json - # if len(special_vocab.special_token_ids) == 0: special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # this one is usually not in config.json anyway special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab.add_to_gguf(self.gguf_writer) diff --git a/llama.cpp b/llama.cpp index aea08fa521bc1..7a48c11153473 100644 --- a/llama.cpp +++ b/llama.cpp @@ -18832,19 +18832,6 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { - auto arch_name = llama_model_arch_name(model->arch); - auto vocab_type = model->vocab.type; - if (strcmp(arch_name, "chatglm") == 0) { - if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 - return token != -1 && ( - token == llama_token_eos(model) || - token == llama_token_eot(model) || - token == 151329 || - token == 151336 || - token == 151338 - ); - } - } return token != -1 && ( token == llama_token_eos(model) || token == llama_token_eot(model) From 0595f03dd171cdad54aafe79199f223d7a2e3f0b Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Wed, 26 Jun 2024 05:58:13 +0000 Subject: [PATCH 16/28] fix chat template bug --- convert-hf-to-gguf.py | 2 +- llama.cpp | 7 ++----- tests/test-chat-template.cpp | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 363c09720c749..f500b3492615a 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3045,7 +3045,7 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) - special_vocab.chat_template = "ChatGLM4" + special_vocab.chat_template = "chatglm4" special_vocab.merges = merges # only add special tokens when they were not already loaded from config.json special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) diff --git a/llama.cpp b/llama.cpp index 76b5c3d91ecac..4abdfa37a7943 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19801,10 +19801,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || - (tmpl.find("add_generation_prompt") != std::string::npos && - tmpl.find("for message in messages") != std::string::npos && - tmpl.find("loop.first") != std::string::npos)) { + } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -19814,7 +19811,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "ChatGLM4") { + } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]") != std::string::npos) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index bcdfeee0e3172..e843be749814d 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -61,7 +61,7 @@ int main(void) { // ChatGLM3 "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", // ChatGLM4 - "ChatGLM4", + "chatglm4", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B From 7357273e088311189dcdc61e28b78ab8bdc74616 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Thu, 27 Jun 2024 02:55:52 +0000 Subject: [PATCH 17/28] add comment to glm prefix and suffix --- convert-hf-to-gguf.py | 2 +- gguf-py/gguf/constants.py | 2 +- gguf-py/pyproject.toml | 2 +- llama.cpp | 20 ++++++++++---------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index f500b3492615a..70ce29f720168 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3043,12 +3043,12 @@ def set_vocab(self): self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_bos_token(False) special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) special_vocab.chat_template = "chatglm4" special_vocab.merges = merges # only add special tokens when they were not already loaded from config.json - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # this one is usually not in config.json anyway diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1e0afe9d39066..80c3478d2969c 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -104,7 +104,7 @@ class Tokenizer: ADD_BOS = "tokenizer.ggml.add_bos_token" ADD_EOS = "tokenizer.ggml.add_eos_token" ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" + REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap" HF_JSON = "tokenizer.huggingface.json" RWKV = "tokenizer.rwkv.world" diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 62129126bdddc..36e63ee3b7cd6 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.9.1" +version = "0.9.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/llama.cpp b/llama.cpp index 4abdfa37a7943..2becfee0e66e6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -14745,10 +14745,10 @@ static std::vector llama_tokenize_internal(const llama_vocab & } // add prefix to chatglm3 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { - output.push_back(64790); - output.push_back(64792); - output.push_back(64795); - output.push_back(30910); + output.push_back(64790); // [gMask] + output.push_back(64792); // sop + output.push_back(64795); // <|user|> + output.push_back(30910); // \n output.push_back(13); } @@ -14787,7 +14787,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & } // add suffix to chatglm3 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { - output.push_back(64796); + output.push_back(64796); // <|assistant|> } } break; case LLAMA_VOCAB_TYPE_BPE: @@ -14799,10 +14799,10 @@ static std::vector llama_tokenize_internal(const llama_vocab & } // add prefix to chatglm4 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { - output.push_back(151331); - output.push_back(151333); - output.push_back(151336); - output.push_back(198); + output.push_back(151331); // [gMASK] + output.push_back(151333); // + output.push_back(151336); // <|user|> + output.push_back(198); // \n } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -14823,7 +14823,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & } // add suffix to chatglm4 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { - output.push_back(151337); + output.push_back(151337); // <|assistant|> } } break; case LLAMA_VOCAB_TYPE_WPM: From e9e47eb9714d6ff212cd05765177bd7e1889d500 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Thu, 27 Jun 2024 06:27:35 +0000 Subject: [PATCH 18/28] fix conflicts and add rope_ratio & ChatGLMForConditionalGeneration --- convert-hf-to-gguf.py | 5 +++-- src/llama.cpp | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 70ce29f720168..c9e6ebf30f5ba 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2896,7 +2896,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] -@Model.register("ChatGLMModel") +@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -3043,7 +3043,6 @@ def set_vocab(self): self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - self.gguf_writer.add_add_bos_token(False) special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) special_vocab.chat_template = "chatglm4" @@ -3070,6 +3069,8 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_rope_dimension_count(64) self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_rope_freq_base(self.hparams.get("rope_ratio", 10000)) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused diff --git a/src/llama.cpp b/src/llama.cpp index 9230d89826603..ad17d5ab55363 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -12281,9 +12281,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - NULL, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); From bbe1926fac33becf04784285b82c5854c0ce312e Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Fri, 28 Jun 2024 06:32:22 +0000 Subject: [PATCH 19/28] fix chat template bug --- src/llama.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 36c5e681f256c..8a53df14a896b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5103,6 +5103,7 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "chatglm-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + vocab.special_bos_id = -1; } else if ( tokenizer_pre == "viking") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; @@ -19828,7 +19829,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) { + } else if (tmpl.find("chatglm3") != std::string::npos || tmpl.find("[gMASK]sop") != std::string::npos) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -19838,7 +19839,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]") != std::string::npos) { + } else if (tmpl.find("chatglm4") != std::string::npos || tmpl.find("[gMASK]") != std::string::npos) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); From d07f0a90c36379c6220e1ecab42208947f4fd95f Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Mon, 1 Jul 2024 02:23:19 +0000 Subject: [PATCH 20/28] fix codestyle --- gguf-py/gguf/constants.py | 54 ++++++++++++++++++------------------ src/llama.cpp | 5 ++-- tests/test-chat-template.cpp | 2 +- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 80c3478d2969c..1a7f2628e0d23 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -87,35 +87,35 @@ class SSM: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" class Tokenizer: - MODEL = "tokenizer.ggml.model" - PRE = "tokenizer.ggml.pre" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - CLS_ID = "tokenizer.ggml.cls_token_id" - MASK_ID = "tokenizer.ggml.mask_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" + MODEL = "tokenizer.ggml.model" + PRE = "tokenizer.ggml.pre" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + CLS_ID = "tokenizer.ggml.cls_token_id" + MASK_ID = "tokenizer.ggml.mask_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - CHAT_TEMPLATE = "tokenizer.chat_template" - CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" - CHAT_TEMPLATES = "tokenizer.chat_templates" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" + CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" + CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants - PREFIX_ID = "tokenizer.ggml.prefix_token_id" - SUFFIX_ID = "tokenizer.ggml.suffix_token_id" - MIDDLE_ID = "tokenizer.ggml.middle_token_id" - EOT_ID = "tokenizer.ggml.eot_token_id" + PREFIX_ID = "tokenizer.ggml.prefix_token_id" + SUFFIX_ID = "tokenizer.ggml.suffix_token_id" + MIDDLE_ID = "tokenizer.ggml.middle_token_id" + EOT_ID = "tokenizer.ggml.eot_token_id" # # recommended mapping of model tensor names for storage in gguf diff --git a/src/llama.cpp b/src/llama.cpp index 8a53df14a896b..b36208fb70087 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5222,7 +5222,6 @@ static void llm_load_vocab( vocab.special_eot_id = 107; } } - try { vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); } catch (const std::exception & e) { @@ -19829,7 +19828,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl.find("chatglm3") != std::string::npos || tmpl.find("[gMASK]sop") != std::string::npos) { + } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -19839,7 +19838,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl.find("chatglm4") != std::string::npos || tmpl.find("[gMASK]") != std::string::npos) { + } else if (tmpl == "chaglm4" || tmpl_contains("[gMASK]")) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e843be749814d..05640b79fa06f 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -61,7 +61,7 @@ int main(void) { // ChatGLM3 "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", // ChatGLM4 - "chatglm4", + u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B From 5e9dba664af717b32ffd1c8f8f0a27c4676cae08 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Mon, 1 Jul 2024 02:50:33 +0000 Subject: [PATCH 21/28] fix conflicts --- src/llama.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index b284d59327b42..61721276d6885 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2084,7 +2084,6 @@ enum e_model { MODEL_16x12B, MODEL_10B_128x3_66B, MODEL_57B_A14B, - MODEL_9B, MODEL_27B, }; @@ -4324,7 +4323,6 @@ static const char * llama_model_type_name(e_model type) { case MODEL_16x12B: return "16x12B"; case MODEL_10B_128x3_66B: return "10B+128x3.66B"; case MODEL_57B_A14B: return "57B.A14B"; - case MODEL_9B: return "9B"; case MODEL_27B: return "27B"; default: return "?B"; } @@ -5011,9 +5009,7 @@ static void llm_load_vocab( if (merges_keyidx == -1) { throw std::runtime_error("cannot find tokenizer merges in model file\n"); } - printf("merges_keyidx: %d\n", merges_keyidx); const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - printf("n_merges: %d\n", n_merges); for (int i = 0; i < n_merges; i++) { const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); From 865dd03f431c7dd72fff0a1e07105e0e7fc655da Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Mon, 1 Jul 2024 03:31:50 +0000 Subject: [PATCH 22/28] modified the general name of glm model --- convert-hf-to-gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a56ce6c88ac69..0838bcb3de3ae 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3104,7 +3104,7 @@ def set_vocab(self): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_head_kv = self.hparams.get("multi_query_group_num", n_head) From 80b381b940f772084cdc62284a6d6ba01dcecdba Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Wed, 3 Jul 2024 02:55:47 +0000 Subject: [PATCH 23/28] fix conflicts --- include/llama.h | 3 +-- src/llama.cpp | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/llama.h b/include/llama.h index a5fb088eae5e6..16095d9a7e74a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -91,8 +91,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_VIKING = 19, - LLAMA_VOCAB_PRE_TYPE_JAIS = 20, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, }; // note: these values should be synchronized with ggml_rope diff --git a/src/llama.cpp b/src/llama.cpp index b3b95f72ca108..cf0add3ab076a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -12567,6 +12567,7 @@ struct llm_build_context { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, From bf54db218e27b723e5021a6b0749f3ab6560b3d3 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Wed, 3 Jul 2024 03:16:09 +0000 Subject: [PATCH 24/28] remove prefix and suffix --- src/llama.cpp | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index cf0add3ab076a..b6b91c3322f27 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5200,13 +5200,6 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = true; vocab.tokenizer_add_eos = false; - // chatglm3 needs to preprocess prefix and suffix - if (tokenizer_pre == "chatglm-spm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM3; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = false; - vocab.tokenizer_add_space_prefix = false; - } } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = true; @@ -15190,14 +15183,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(vocab.special_bos_id); is_prev_special = true; } - // add prefix to chatglm3 - if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { - output.push_back(64790); // [gMask] - output.push_back(64792); // sop - output.push_back(64795); // <|user|> - output.push_back(30910); // \n - output.push_back(13); - } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -15232,10 +15217,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } - // add suffix to chatglm3 - if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { - output.push_back(64796); // <|assistant|> - } } break; case LLAMA_VOCAB_TYPE_BPE: { @@ -15244,13 +15225,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & if (add_special) { tokenizer.append_bos(output); } - // add prefix to chatglm4 - if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { - output.push_back(151331); // [gMASK] - output.push_back(151333); // - output.push_back(151336); // <|user|> - output.push_back(198); // \n - } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -15268,10 +15242,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & tokenizer.append_eos(output); tokenizer.check_double_bos_eos(output); } - // add suffix to chatglm4 - if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { - output.push_back(151337); // <|assistant|> - } } break; case LLAMA_VOCAB_TYPE_WPM: { From bce74d8212b5548409f6ef048b78cc8b6a06d92b Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Wed, 3 Jul 2024 08:57:03 +0000 Subject: [PATCH 25/28] use normal glm4 chattempalte & use LLM_FFN_SWIGLU in phi3 --- convert-hf-to-gguf.py | 1 - src/llama.cpp | 19 ++++++------------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 15603007183b6..8c0fa5d8ee8af 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3209,7 +3209,6 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) - special_vocab.chat_template = "chatglm4" special_vocab.merges = merges # only add special tokens when they were not already loaded from config.json special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) diff --git a/src/llama.cpp b/src/llama.cpp index b6b91c3322f27..eb1bde2697c88 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10326,19 +10326,12 @@ struct llm_build_context { // special-case: the up and gate tensors are merged into a single tensor // TOOD: support into llm_build_ffn { - struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); - cb(up, "ffn_up", il); - - auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0)); - auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2)); - - y = ggml_mul(ctx0, y, ggml_silu(ctx0, g)); - cb(y, "ffn_gate", il); - - auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y); - cb(down, "ffn_down", il); - - cur = down; + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } From 3be4270fc8e86a1f8e4f5e37a24994cd6ef757fb Mon Sep 17 00:00:00 2001 From: Umpire2018 <138990495+Umpire2018@users.noreply.github.com> Date: Fri, 5 Jul 2024 14:44:29 +0000 Subject: [PATCH 26/28] fix: resolve Flake8 errors in `convert-hf-to-gguf.py` - Fix E302 by adding two blank lines before top-level function definitions - Replace print statements to fix NP100 - Fix E303 by ensuring only one blank line between lines of code --- convert-hf-to-gguf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8c0fa5d8ee8af..5b85f49decf57 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3060,6 +3060,7 @@ def write_tensors(self): super().write_tensors() self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) + @Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -3077,8 +3078,6 @@ def set_vocab_chatglm3(self): assert max(tokenizer.get_vocab().values()) < vocab_size role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens - print(vocab_size) - print(max(tokenizer.get_vocab().values())) for token_id in range(vocab_size): piece = tokenizer._convert_id_to_token(token_id) if token_id == 0: @@ -3234,7 +3233,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_add_bos_token(False) self.gguf_writer.add_rope_freq_base(self.hparams.get("rope_ratio", 10000)) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused From 5b760f26a4c31055823632ad8e0a528c6da4422c Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Sun, 7 Jul 2024 10:27:05 +0000 Subject: [PATCH 27/28] fix rope ratio to solve incorrect answers --- convert-hf-to-gguf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 5b85f49decf57..1ae7abbaf8cea 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -3231,7 +3231,10 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_rope_dimension_count(64) self.gguf_writer.add_add_bos_token(False) - self.gguf_writer.add_rope_freq_base(self.hparams.get("rope_ratio", 10000)) + rope_ratio = 10000 + if "rope_ratio" in self.hparams: + rope_ratio = rope_ratio * self.hparams["rope_ratio"] + self.gguf_writer.add_rope_freq_base(rope_ratio) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused From 4e85b06de7a832ad1525a4b6204ffff84c6999a4 Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Sun, 7 Jul 2024 11:42:54 +0000 Subject: [PATCH 28/28] fix by comments --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ffa66cb6dc246..6ee41d3a118e5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3349,10 +3349,10 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_rope_dimension_count(64) self.gguf_writer.add_add_bos_token(False) - rope_ratio = 10000 + rope_freq = 10000 if "rope_ratio" in self.hparams: - rope_ratio = rope_ratio * self.hparams["rope_ratio"] - self.gguf_writer.add_rope_freq_base(rope_ratio) + rope_freq = rope_freq * self.hparams["rope_ratio"] + self.gguf_writer.add_rope_freq_base(rope_freq) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused