diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ebb5ca376133b..4cd28092472e5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3395,6 +3395,135 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] + +@Model.register("NemotronForCausalLM") +class Nemotron4Model(Model): + model_arch = gguf.MODEL_ARCH.NEMOTRON4 + + def set_vocab(self): + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + + assert sentencepiece_model.trainer_spec.model_type == 2 # BPE + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if (token_id >= vocab_size): + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + self.gguf_writer.add_tokenizer_model("nemotron") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + + special_vocab = gguf.SpecialVocab( + self.dir_model, n_vocab=len(tokens), + special_token_types = ('bos', 'eos', 'eot') + ) + special_vocab._set_special_token("eot", 5) # + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(False) + + def set_gguf_parameters(self): + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count( + int(self.hparams["partial_rotary_factor"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), + ) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".layer_norm.weight") or name == "final_layernorm.weight": + logger.info(f"Adding 1.0 to {name} tensor data, see NeMo zero_centered_gamma documentation") + data_torch = data_torch + 1.0 + if name.endswith(".linear_qkv.weight"): + n_head = self.find_hparam(["num_attention_heads"]) + n_head_kv = self.find_hparam(["num_key_value_heads"]) + head_dim = self.hparams["hidden_size"] // n_head + + qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) + q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head) + k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) + v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) + data_torch = torch.cat((q, k, v)).reshape_as(data_torch) + + + return [(self.map_tensor_name(name), data_torch)] + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a95a44237e348..be8b9b762b5c0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -166,6 +166,7 @@ class MODEL_ARCH(IntEnum): BITNET = auto() T5 = auto() JAIS = auto() + NEMOTRON4 = auto() class MODEL_TENSOR(IntEnum): @@ -293,6 +294,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON4: "nemotron4", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -996,6 +998,17 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7264240f5e17a..7dd30b93c4ace 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -75,6 +75,7 @@ class TensorNameMap: "transformer.rms_norm", # Grok "encoder.final_layernorm", # chatglm "transformer.norm", # openelm + "final_layernorm", # nemotron4 ), # Rope frequencies @@ -107,6 +108,7 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx "encoder.layers.{bid}.input_layernorm", # chatglm "transformer.layers.{bid}.attn_norm", # openelm + "model.layers.{bid}.self_attention.linear_qkv.layer_norm" # nemotron4 ), # Attention norm 2 @@ -131,6 +133,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.qkv_proj", # phi3 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm + "model.layers.{bid}.self_attention.linear_qkv" # nemotron4 ), # Attention query @@ -190,6 +193,7 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm + "model.layers.{bid}.self_attention.linear_proj", # nemotron4 ), # Attention output norm @@ -227,6 +231,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.mlp.linear_fc1.layer_norm", # nemotron4 ), # Post feed-forward norm @@ -277,6 +282,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm + "model.layers.{bid}.mlp.linear_fc1", # nemotron4 ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -347,6 +353,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm + "model.layers.{bid}.mlp.linear_fc2", # nemotron4 ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/include/llama.h b/include/llama.h index 3970c3aebcd62..de16434c31d0a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -68,6 +68,7 @@ extern "C" { LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_NTN = 5, // Nemotron tokenizer based on SentencePiece BPE }; // pre-tokenization types diff --git a/src/llama.cpp b/src/llama.cpp index ed77ed918e554..621c1d8756a51 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -239,6 +239,7 @@ enum llm_arch { LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON4, LLM_ARCH_UNKNOWN, }; @@ -283,6 +284,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON4, "nemotron4" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1296,6 +1298,20 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_NEMOTRON4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -4562,6 +4578,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_WPM: return "WPM"; case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_NTN: return "NTN"; default: return "unknown"; } } @@ -5217,6 +5234,12 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + model.type = e_model::MODEL_UNKNOWN; + } + default: (void)0; } @@ -5361,6 +5384,17 @@ static void llm_load_vocab( } #endif } + } else if (tokenizer_model == "nemotron") { + vocab.type = LLAMA_VOCAB_TYPE_NTN; + + // default special tokens + vocab.special_bos_id = 2; + vocab.special_eos_id = 3; + vocab.special_unk_id = 1; + vocab.special_sep_id = -1; + vocab.special_pad_id = 0; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -5469,6 +5503,12 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = false; vocab.tokenizer_add_eos = true; + } else if (vocab.type == LLAMA_VOCAB_TYPE_NTN) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = true; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } @@ -5528,7 +5568,7 @@ static void llm_load_vocab( GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' - if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + if (vocab.type == LLAMA_VOCAB_TYPE_SPM || vocab.type == LLAMA_VOCAB_TYPE_NTN) { // For Fill-In-the-Middle (FIM)/infill models which where converted // prior to support of FIM special tokens in GGUF, the following // will allow those models to continue to work. The general names @@ -7504,6 +7544,38 @@ static bool llm_load_tensors( layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); } } break; + case LLM_ARCH_NEMOTRON4: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -13607,6 +13679,110 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_nemotron4() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // ffn + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -13854,6 +14030,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_jais(); } break; + case LLM_ARCH_NEMOTRON4: + { + result = llm.build_nemotron4(); + } break; default: GGML_ASSERT(false); } @@ -15148,7 +15328,8 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { const auto & token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { + case LLAMA_VOCAB_TYPE_UGM: + case LLAMA_VOCAB_TYPE_NTN: { auto buf = token_data.text.substr(3, 2); return strtol(buf.c_str(), NULL, 16); } @@ -15169,7 +15350,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { + case LLAMA_VOCAB_TYPE_UGM: + case LLAMA_VOCAB_TYPE_NTN: { const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; auto token = vocab.token_to_id.find(buf); if (token != vocab.token_to_id.end()) { @@ -15202,6 +15384,7 @@ struct llm_symbol { index next; const char * text; size_t n; + bool freeze; }; static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); @@ -15532,7 +15715,7 @@ struct llm_tokenizer_bpe { size_t offset = 0; if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { - symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size(), false}); offset = word.size(); } @@ -16148,6 +16331,237 @@ struct llm_tokenizer_ugm { struct naive_trie token_matcher; }; +struct llm_tokenizer_ntn { + llm_tokenizer_ntn(const llama_vocab & vocab) : vocab(vocab) { + for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { + const auto &token_data = vocab.id_to_token[id]; + + if (llama_is_user_defined_token(vocab, id)) { + user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); + } + } + } + + void tokenize(const std::string & text, std::vector & output) { + // normalize the input first + std::string normalized; + normalize(text, &normalized); + size_t input_len = normalized.size(); + if (input_len == 0) { + return; + } + + // split string into utf8 chars + int index = 0; + size_t offs = 0; + while (offs < normalized.size()) { + llm_symbol sym; + sym.freeze = false; + auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&normalized[offs], normalized.size() - offs); + size_t len = 0; + if (user_defined_token_match.second > 0) { + sym.freeze = true; + len = user_defined_token_match.second; + } else { + len = utf8_len(normalized[offs]); + } + sym.text = normalized.c_str() + offs; + sym.n = std::min(len, normalized.size() - offs); + offs += sym.n; + sym.prev = index - 1; + sym.next = offs == normalized.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + + // seed the work queue with all possible 2-character tokens. + for (size_t i = 1; i < symbols.size(); ++i) { + try_add_bigram(i - 1, i); + } + + // keep substituting the highest frequency pairs for as long as we can. + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_sym = symbols[bigram.left]; + auto & right_sym = symbols[bigram.right]; + + // if one of the symbols already got merged, skip it. + if (left_sym.n == 0 || right_sym.n == 0 || + left_sym.n + right_sym.n != bigram.size) { + continue; + } + + // merge the right sym into the left one + left_sym.n += right_sym.n; + right_sym.n = 0; + + //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.normalized, bigram.size); + + // remove the right sym from the chain + left_sym.next = right_sym.next; + if (right_sym.next >= 0) { + symbols[right_sym.next].prev = bigram.left; + } + + // find more substitutions + try_add_bigram(left_sym.prev, bigram.left); + try_add_bigram(bigram.left, left_sym.next); + } + + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + resegment(symbol, output); + } + } + +private: + void resegment(llm_symbol & symbol, std::vector & output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab.token_to_id.find(text); + + // Do we need to support is_unused? + if (token != vocab.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + output.reserve(output.size() + symbol.n); + for (int j = 0; j < (int)symbol.n; ++j) { + llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]); + output.push_back(token_id); + } + return; + } + + resegment(symbols[p->second.first], output); + resegment(symbols[p->second.second], output); + } + + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1 || symbols[left].freeze || symbols[right].freeze) { + return; + } + + const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); + auto token = vocab.token_to_id.find(text); + + if (token == vocab.token_to_id.end()) { + return; + } + + if (static_cast((*token).second) >= vocab.id_to_token.size()) { + return; + } + + const auto & tok_data = vocab.id_to_token[(*token).second]; + + llm_bigram_spm bigram; + bigram.left = left; + bigram.right = right; + bigram.score = tok_data.score; + bigram.size = text.size(); + + work_queue.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); + } + + const llama_vocab & vocab; + + std::vector symbols; + llm_bigram_spm::queue work_queue; + + std::map> rev_merge; + + // helper structure for returning normalization results + struct normalization_result { + const char * normalized; + size_t normalized_len; + size_t consumed_input; + }; + + void normalize(const std::string& input, std::string * normalized) { + normalized->clear(); + normalized->reserve(input.size() * 3); + + const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + + bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; + bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; + bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces; + + bool is_space_prepended = false; + bool processing_non_ws = false; + + size_t input_len = input.size(); + + for (size_t input_offset = 0; input_offset < input_len; ) { + auto norm_res = normalize_prefix(input, input_offset); + for (size_t i = 0; i < norm_res.normalized_len; i++) { + char c = norm_res.normalized[i]; + if (c != ' ') { + if (!processing_non_ws) { + processing_non_ws = true; + if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) { + normalized->append(space); + is_space_prepended = true; + } + } + normalized->push_back(c); + } else { + if (processing_non_ws) { + processing_non_ws = false; + } + if (!shall_merge_spaces) { + normalized->append(space); + } + } + } + + input_offset += norm_res.consumed_input; + } + + if (shall_append_space) { + normalized->append(space); + } + } + + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { + if (input_offset == input.size()) { + return { &input[input_offset], 0, 0 }; + } + + // if input prefix matches some user-defined token return this token as normalization result + auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + if (user_defined_token_match.second > 0) { + return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; + } + + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } + + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + struct naive_trie user_defined_token_matcher; +}; + typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, @@ -16436,6 +16850,39 @@ static std::vector llama_tokenize_internal(const llama_vocab & "Are you sure this is what you want?\n", __FUNCTION__); } + if (add_special && vocab.tokenizer_add_eos == 1) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } + } break; + case LLAMA_VOCAB_TYPE_NTN: + { + llm_tokenizer_ntn tokenizer(vocab); + + if (add_special && vocab.tokenizer_add_bos != 0) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + + if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + if (add_special && vocab.tokenizer_add_eos == 1) { GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); @@ -19390,6 +19837,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: + case LLM_ARCH_NEMOTRON4: return LLAMA_ROPE_TYPE_NEOX; // all model arches should be listed explicitly here @@ -21101,7 +21549,7 @@ int32_t llama_tokenize( int32_t n_tokens_max, bool add_special, bool parse_special) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, false); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); @@ -21172,7 +21620,8 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token switch (llama_vocab_get_type(model->vocab)) { case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { + case LLAMA_VOCAB_TYPE_UGM: + case LLAMA_VOCAB_TYPE_NTN: { // NOTE: we accept all unsupported token types, // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {