Skip to content

Commit

Permalink
[Model] Add support for Nemotron architecture (mlc-ai#3069)
Browse files Browse the repository at this point in the history
This PR adds support for Nemotron architecture, and is in reference
to mlc-ai#2901 [Request for Nemotron-Mini-4B-Instruct]

Based on my analysis of the Nemotron architecture in
the huggingface repository, it appears to share similarities
with the Llama architecture, but with the following key distinctions:

- The activation function used in the MLP is `relu2` (squared ReLU).
- The MLP includes `up_proj` and `down_proj`, but does not have
a `gate_proj` as seen in Llama.
- It uses `layernorm1p`, and the normalization layer incorporates a bias term.
- The architecture employs a `partial_rotary_factor`, which is similar
to the approach used in the Phi architecture.
  • Loading branch information
hrishi121 authored Dec 19, 2024
1 parent 8a1bfd6 commit 9a33772
Show file tree
Hide file tree
Showing 8 changed files with 734 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/mlc_llm/conversation_template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
llama,
llava,
mistral,
nemotron,
oasst,
olmo,
orion,
Expand Down
27 changes: 27 additions & 0 deletions python/mlc_llm/conversation_template/nemotron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""nemotron default templates"""

from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders

from .registry import ConvTemplateRegistry

# Nemotron template
# https://huggingface.co/nvidia/Nemotron-Mini-4B-Instruct/blob/6a417790c444fd65a3da6a5c8821de6afc9654a6/tokenizer_config.json#L8030
ConvTemplateRegistry.register_conv_template(
Conversation(
name="nemotron",
system_template=("<extra_id_0>System\n" f"{MessagePlaceholders.SYSTEM.value}\n\n"),
system_message="",
roles={
"user": "<extra_id_1>User",
"assistant": "<extra_id_1>Assistant",
"tool": "<extra_id_1>Tool",
},
seps=["\n"],
role_content_sep="\n",
role_empty_sep="\n",
stop_str=["</s>"],
stop_token_ids=[3],
system_prefix_token_ids=[2],
add_role_after_system_message=True,
)
)
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
"deepseek_v2",
"deepseek",
"olmo",
"nemotron",
}
17 changes: 17 additions & 0 deletions python/mlc_llm/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .minicpm import minicpm_loader, minicpm_model, minicpm_quantization
from .mistral import mistral_loader, mistral_model, mistral_quantization
from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization
from .nemotron import nemotron_loader, nemotron_model, nemotron_quantization
from .olmo import olmo_loader, olmo_model, olmo_quantization
from .orion import orion_loader, orion_model, orion_quantization
from .phi import phi_loader, phi_model, phi_quantization
Expand Down Expand Up @@ -565,4 +566,20 @@ class Model:
"per-tensor-quant": olmo_quantization.per_tensor_quant,
},
),
"nemotron": Model(
name="nemotron",
model=nemotron_model.NemotronForCausalLM,
config=nemotron_model.NemotronConfig,
source={
"huggingface-torch": nemotron_loader.huggingface,
"huggingface-safetensor": nemotron_loader.huggingface,
},
quantize={
"no-quant": nemotron_quantization.no_quant,
"group-quant": nemotron_quantization.group_quant,
"ft-quant": nemotron_quantization.ft_quant,
"awq": nemotron_quantization.awq_quant,
"per-tensor-quant": nemotron_quantization.per_tensor_quant,
},
),
}
Empty file.
76 changes: 76 additions & 0 deletions python/mlc_llm/model/nemotron/nemotron_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
This file specifies how MLC's Nemotron parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .nemotron_model import NemotronConfig, NemotronForCausalLM


def huggingface(model_config: NemotronConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.
Parameters
----------
model_config : NemotronConfig
The configuration of the Nemotron model.
quantization : Quantization
The quantization configuration.
Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
"""
model = NemotronForCausalLM(model_config)
if quantization is not None:
model.to(quantization.model_dtype)
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(),
allow_extern=True,
)
named_parameters = dict(_named_params)

mapping = ExternMapping()

for i in range(model_config.num_hidden_layers):
# Add QKV in self attention
attn = f"model.layers.{i}.self_attn"
mlc_name = f"{attn}.qkv_proj.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{attn}.q_proj.weight",
f"{attn}.k_proj.weight",
f"{attn}.v_proj.weight",
],
functools.partial(
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
dtype=mlc_param.dtype,
),
)

# inv_freq is not used in the model
mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: x.astype(dtype),
dtype=mlc_param.dtype,
),
)

return mapping
Loading

0 comments on commit 9a33772

Please sign in to comment.