Add initial support for NemotronForCausalLM.
sszymczy committed Jul 13, 2024
1 parent a977c11 commit 0b89395
Showing 5 changed files with 605 additions and 6 deletions.
129 changes: 129 additions & 0 deletions convert_hf_to_gguf.py
@@ -3395,6 +3395,135 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        name = name.removeprefix("transformer.")
        return [(self.map_tensor_name(name), data_torch)]


@Model.register("NemotronForCausalLM")
class Nemotron4Model(Model):
    model_arch = gguf.MODEL_ARCH.NEMOTRON4

    def set_vocab(self):
        # to avoid TypeError: Descriptors cannot be created directly
        # exception when importing sentencepiece_model_pb2
        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
        from sentencepiece import SentencePieceProcessor
        from sentencepiece import sentencepiece_model_pb2 as model

        tokenizer_path = self.dir_model / 'tokenizer.model'

        if not tokenizer_path.is_file():
            raise FileNotFoundError(f"File not found: {tokenizer_path}")

        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())

        assert sentencepiece_model.trainer_spec.model_type == 2  # BPE

        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces

        tokenizer = SentencePieceProcessor()
        tokenizer.LoadFromFile(str(tokenizer_path))

        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
        scores: list[float] = [-10000.0] * vocab_size
        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size

        for token_id in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.GetScore(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.IsUnknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.IsControl(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.IsUnused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.IsByte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens[token_id] = text
            scores[token_id] = score
            toktypes[token_id] = toktype

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)
                for key in added_tokens_json:
                    token_id = added_tokens_json[key]
                    if token_id >= vocab_size:
                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
                        continue

                    tokens[token_id] = key.encode("utf-8")
                    scores[token_id] = -1000.0
                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

        if vocab_size > len(tokens):
            pad_count = vocab_size - len(tokens)
            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
            for i in range(1, pad_count + 1):
                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
                scores.append(-1000.0)
                toktypes.append(SentencePieceTokenTypes.UNUSED)

        self.gguf_writer.add_tokenizer_model("nemotron")
        self.gguf_writer.add_tokenizer_pre("default")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_add_space_prefix(add_prefix)
        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)

        special_vocab = gguf.SpecialVocab(
            self.dir_model, n_vocab=len(tokens),
            special_token_types=('bos', 'eos', 'eot')
        )
        special_vocab._set_special_token("eot", 5)  # <extra_id_1>
        special_vocab.add_to_gguf(self.gguf_writer)

        self.gguf_writer.add_add_bos_token(True)
        self.gguf_writer.add_add_eos_token(False)

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["partial_rotary_factor"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
        self.gguf_writer.add_file_type(self.ftype)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if name.endswith(".layer_norm.weight") or name == "final_layernorm.weight":
            # NeMo stores layernorm gamma zero-centered (zero_centered_gamma),
            # i.e. as (gamma - 1); add the 1.0 back here
            logger.info(f"Adding 1.0 to {name} tensor data, see NeMo zero_centered_gamma documentation")
            data_torch = data_torch + 1.0
        if name.endswith(".linear_qkv.weight"):
            n_head = self.find_hparam(["num_attention_heads"])
            n_head_kv = self.find_hparam(["num_key_value_heads"])
            head_dim = self.hparams["hidden_size"] // n_head

            # NeMo interleaves QKV per KV group as [q .. q, k, v] blocks of
            # head_dim rows each; split them out and re-pack as Q, then K, then V
            qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
            q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
            k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
            v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
            data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

        return [(self.map_tensor_name(name), data_torch)]


###### CONVERSION LOGIC ######
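For reference, the set_vocab() hunk above can be exercised on its own. The sketch below inspects a Nemotron tokenizer.model with the same sentencepiece calls the converter relies on; the "tokenizer.model" path is a placeholder, not a file shipped with this commit.

# Standalone sketch: peek at a SentencePiece BPE tokenizer the same way
# set_vocab() above does. Assumes sentencepiece is installed and a
# tokenizer.model file sits in the working directory.
import os

# avoid "TypeError: Descriptors cannot be created directly" on import
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model

proto = model.ModelProto()
with open("tokenizer.model", "rb") as f:
    proto.ParseFromString(f.read())

print("model_type:", proto.trainer_spec.model_type)  # 2 == BPE, as asserted above
print("add_dummy_prefix:", proto.normalizer_spec.add_dummy_prefix)

sp = SentencePieceProcessor()
sp.LoadFromFile("tokenizer.model")
for tid in range(5):
    print(tid, sp.IdToPiece(tid), sp.GetScore(tid), sp.IsControl(tid))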


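The QKV de-interleave in modify_tensors() is the subtle part of this change: the NeMo checkpoint stores the fused QKV weight as n_head_kv groups of [q .. q, k, v] row blocks (head_dim rows each), whereas llama.cpp expects all Q rows, then all K, then all V. Below is a shape-level sanity check of that reshape with toy dimensions, made up for illustration rather than taken from any real Nemotron config.

# Toy-sized check of the QKV split used above; dimensions are illustrative.
import torch

n_head, n_head_kv, head_dim = 8, 4, 16
hidden = n_head * head_dim                      # hidden_size == 128
qkv_rows = (n_head + 2 * n_head_kv) * head_dim  # 256 rows in the fused weight

weight = torch.arange(qkv_rows * hidden, dtype=torch.float32).reshape(qkv_rows, hidden)

# one group per KV head: n_head // n_head_kv query blocks, then one K and one V
qkv = weight.view(n_head_kv, n_head // n_head_kv + 2, head_dim, hidden)
q = qkv[:, :-2].reshape(n_head * head_dim, hidden)
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, hidden)
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, hidden)

out = torch.cat((q, k, v)).reshape_as(weight)
assert out.shape == weight.shape
assert torch.equal(out[:n_head * head_dim], q)  # Q first, then K, then V

The same head_dim = hidden_size // n_head also feeds set_gguf_parameters(), where the RoPE dimension count is partial_rotary_factor times head_dim (for instance, a factor of 0.5 on 128-wide heads gives 64 rotary dimensions).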
13 changes: 13 additions & 0 deletions gguf-py/gguf/constants.py
@@ -166,6 +166,7 @@ class MODEL_ARCH(IntEnum):
    BITNET = auto()
    T5 = auto()
    JAIS = auto()
    NEMOTRON4 = auto()


class MODEL_TENSOR(IntEnum):
@@ -293,6 +294,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.BITNET: "bitnet",
    MODEL_ARCH.T5: "t5",
    MODEL_ARCH.JAIS: "jais",
    MODEL_ARCH.NEMOTRON4: "nemotron4",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -996,6 +998,17 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.NEMOTRON4: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    # TODO
}

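With the constants.py entries above in place, the new architecture becomes discoverable from Python. A minimal sketch, assuming a gguf-py checkout that already contains this change:

# Sketch: enumerate the GGUF tensor names registered for the new arch.
import gguf

print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.NEMOTRON4])  # "nemotron4"
for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.NEMOTRON4]:
    # TENSOR_NAMES holds templates; "{bid}" is filled with the block index
    print(gguf.TENSOR_NAMES[t])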
7 changes: 7 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -75,6 +75,7 @@ class TensorNameMap:
"transformer.rms_norm", # Grok
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
"final_layernorm", # nemotron4
),

# Rope frequencies
@@ -107,6 +108,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
"model.layers.{bid}.self_attention.linear_qkv.layer_norm" # nemotron4
),

# Attention norm 2
@@ -131,6 +133,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
"model.layers.{bid}.self_attention.linear_qkv" # nemotron4
),

# Attention query
@@ -190,6 +193,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"model.layers.{bid}.self_attention.linear_proj", # nemotron4
),

# Attention output norm
@@ -227,6 +231,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.mlp.linear_fc1.layer_norm", # nemotron4
),

# Post feed-forward norm
@@ -277,6 +282,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"model.layers.{bid}.mlp.linear_fc1", # nemotron4
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -347,6 +353,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.{bid}.mlp.linear_fc2", # nemotron4
),

MODEL_TENSOR.FFN_DOWN_EXP: (
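These mappings can be sanity-checked end to end with the existing get_tensor_name_map() helper in gguf-py. A sketch, assuming a gguf-py that includes the NEMOTRON4 entries above; the layer index 0 and the 32-block count are arbitrary placeholders:

# Sketch: resolve Nemotron checkpoint tensor names to GGUF names.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NEMOTRON4, 32)
for hf_name in (
    "model.layers.0.self_attention.linear_qkv.weight",
    "model.layers.0.self_attention.linear_proj.weight",
    "model.layers.0.mlp.linear_fc1.weight",
    "model.layers.0.mlp.linear_fc2.weight",
    "final_layernorm.weight",
):
    # e.g. "...mlp.linear_fc2.weight" should resolve to "blk.0.ffn_down.weight"
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))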
1 change: 1 addition & 0 deletions include/llama.h
@@ -68,6 +68,7 @@ extern "C" {
        LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
        LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
        LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
        LLAMA_VOCAB_TYPE_NTN = 5, // Nemotron tokenizer based on SentencePiece BPE
    };

    // pre-tokenization types