[pull] main from mlc-ai:main #297

Merged: merged 1 commit on Oct 15, 2024
1 change: 1 addition & 0 deletions python/mlc_llm/conversation_template/__init__.py
@@ -8,6 +8,7 @@
# model preset templates
from . import (
    cohere,
    deepseek,
    deepseek_v2,
    dolly,
    gemma,
21 changes: 21 additions & 0 deletions python/mlc_llm/conversation_template/deepseek.py
@@ -0,0 +1,21 @@
"""Deepseek default templates"""

from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders

from .registry import ConvTemplateRegistry

# Deepseek
ConvTemplateRegistry.register_conv_template(
    Conversation(
        name="deepseek",
        system_template=f"{MessagePlaceholders.SYSTEM.value}",
        system_message="",
        system_prefix_token_ids=[100000],
        roles={"user": "User", "assistant": "Assistant"},
        seps=["\n\n", "<|end▁of▁sentence|>"],
        role_content_sep=": ",
        role_empty_sep=":",
        stop_str=["<|end▁of▁sentence|>"],
        stop_token_ids=[100001],
    )
)
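For orientation, these fields imply a prompt of roughly the following shape. The sketch below is a hypothetical helper, not the actual Conversation rendering logic; it only illustrates how the roles and separators combine (the BOS token id 100000 is prepended as a token id via system_prefix_token_ids, not as text).

def render_deepseek_prompt(system: str, history: list[tuple[str, str]], query: str) -> str:
    """Illustrative only: 'User'/'Assistant' roles joined to content by ': '
    (role_content_sep), '\\n\\n' (seps[0]) after user turns,
    '<|end▁of▁sentence|>' (seps[1]) after assistant turns, and a bare
    'Assistant:' (role_empty_sep) left open for the model to continue."""
    out = system  # system_template is just the raw system message
    for user_msg, assistant_msg in history:
        out += f"User: {user_msg}\n\n"
        out += f"Assistant: {assistant_msg}<|end▁of▁sentence|>"
    out += f"User: {query}\n\nAssistant:"
    return out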
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
@@ -305,4 +305,5 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
"tinyllama_v1_0",
"aya-23",
"deepseek_v2",
"deepseek",
}
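Usage note: once "deepseek" is in this allow-list, the template can be selected when generating a model's chat config, e.g. mlc_llm gen_config <model-path> --conv-template deepseek -o <output-dir> (placeholders shown; other required flags such as the quantization scheme are omitted here).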
Empty file.
149 changes: 149 additions & 0 deletions python/mlc_llm/model/deepseek/deepseek_loader.py
@@ -0,0 +1,149 @@
"""
This file specifies how MLC's Deepseek parameters map from other formats, for example
HuggingFace PyTorch and HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .deepseek_model import DeepseekConfig, DeepseekForCausalLM


def huggingface(model_config: DeepseekConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.

Parameters
----------
model_config : MiniCPMConfig
The configuration of the MiniCPM model.

quantization : Quantization
The quantization configuration.

Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
"""
    model = DeepseekForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # map attention weight
        attn = f"model.layers.{i}.self_attn"
        for weight_type in ["weight"]:
            mlc_name = f"{attn}.wqkv_pack.{weight_type}"
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{attn}.q_proj.{weight_type}",
                    f"{attn}.k_proj.{weight_type}",
                    f"{attn}.v_proj.{weight_type}",
                ],
                functools.partial(
                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )

    for i in range(model_config.num_hidden_layers):
        if i >= model_config.first_k_dense_replace and i % model_config.moe_layer_freq == 0:
            # map mlp shared expert weight
            mlp = f"model.layers.{i}.mlp"
            shared_expert = f"{mlp}.shared_experts"
            mlc_name = f"{shared_expert}.gate_up_proj.weight"
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{shared_expert}.gate_proj.weight",
                    f"{shared_expert}.up_proj.weight",
                ],
                functools.partial(
                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
            # map mlp moe gate and up weight
            mlc_name = f"{mlp}.moe_gate_up_proj.weight"
            mlc_param = named_parameters[mlc_name]

            def combine_expert_gate_up(*hf_params, dtype):
                stack = []
                for i in range(0, len(hf_params), 2):
                    stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))
                return np.stack(stack, axis=0).astype(dtype)

            mapping.add_mapping(
                mlc_name,
                functools.reduce(
                    lambda a, b: a + b,
                    [
                        [
                            f"{mlp}.experts.{expert_id}.gate_proj.weight",
                            f"{mlp}.experts.{expert_id}.up_proj.weight",
                        ]
                        for expert_id in range(model_config.n_routed_experts)
                    ],
                ),
                functools.partial(
                    combine_expert_gate_up,
                    dtype=mlc_param.dtype,
                ),
            )

            # map mlp moe down weight
            mlc_name = f"{mlp}.moe_down_proj.weight"
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{mlp}.experts.{expert_id}.down_proj.weight"
                    for expert_id in range(model_config.n_routed_experts)
                ],
                functools.partial(
                    lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
        else:
            # map mlp weight
            mlp = f"model.layers.{i}.mlp"
            mlc_name = f"{mlp}.gate_up_proj.weight"
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{mlp}.gate_proj.weight",
                    f"{mlp}.up_proj.weight",
                ],
                functools.partial(
                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
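To make the registered transforms concrete, here is a minimal, self-contained numpy sketch of the two packing steps the mapping performs, using toy sizes rather than real Deepseek dimensions (every name and shape below is illustrative only):

import numpy as np

hidden, inter, n_experts = 4, 6, 2  # toy sizes for illustration

# Attention packing: q/k/v projection matrices are concatenated row-wise
# into one wqkv matrix, as in the wqkv_pack mapping above.
q = np.ones((hidden, hidden))
k = 2 * np.ones((hidden, hidden))
v = 3 * np.ones((hidden, hidden))
wqkv = np.concatenate([q, k, v], axis=0).astype("float16")
assert wqkv.shape == (3 * hidden, hidden)

# MoE packing: each expert's gate_proj and up_proj are concatenated row-wise,
# then all experts are stacked along a new leading axis; this reproduces
# combine_expert_gate_up applied to [gate_0, up_0, gate_1, up_1, ...].
gates = [np.full((inter, hidden), e) for e in range(n_experts)]
ups = [np.full((inter, hidden), 10 + e) for e in range(n_experts)]
moe_gate_up = np.stack(
    [np.concatenate([g, u], axis=0) for g, u in zip(gates, ups)], axis=0
).astype("float16")
assert moe_gate_up.shape == (n_experts, 2 * inter, hidden)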