[Model] Support multi-GPU for Deepseek-v2 (mlc-ai#3080)
This PR supports tensor parallelism for the Deepseek-v2 model.
MasterJH5574 authored Jan 4, 2025
1 parent 6faf68e commit 8243b2b
Showing 2 changed files with 95 additions and 18 deletions.
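
Orientation note on the sharding scheme used throughout this diff: weights that expand the hidden size (the gate_up projections) are split along their output dimension across GPUs, weights that project back to the hidden size (the down and o projections) are split along their input dimension, and the per-worker partial outputs are summed with an allreduce. Below is a minimal NumPy sketch of that pattern; the shapes, shard count, and activation are illustrative only, not MLC code. The fused gate/up weight is split per half before sharding, which is what the segs=[...] arguments in the shard hints appear to encode.

# Minimal NumPy sketch of the tensor-parallel MLP pattern (illustrative only).
import numpy as np

shards, hidden, inter = 2, 8, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((1, hidden))
w_gate_up = rng.standard_normal((2 * inter, hidden))  # fused gate+up, like gate_up_proj
w_down = rng.standard_normal((hidden, inter))          # like down_proj

def silu(v):
    return v / (1.0 + np.exp(-v))

def mlp(x, w_gate_up, w_down):
    gate, up = np.split(x @ w_gate_up.T, 2, axis=-1)
    return (silu(gate) * up) @ w_down.T

ref = mlp(x, w_gate_up, w_down)  # single-GPU reference

# Tensor parallel: shard the gate and up halves separately along dim 0 (column-parallel),
# shard down_proj along dim 1 (row-parallel), then sum the partial outputs (allreduce).
gate_full, up_full = np.split(w_gate_up, 2, axis=0)
partials = []
for i in range(shards):
    w_gate_up_i = np.concatenate(
        [np.split(gate_full, shards, axis=0)[i], np.split(up_full, shards, axis=0)[i]], axis=0
    )
    w_down_i = np.split(w_down, shards, axis=1)[i]
    partials.append(mlp(x, w_gate_up_i, w_down_i))
out = np.sum(partials, axis=0)  # stands in for the ccl_allreduce step

assert np.allclose(out, ref)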
4 changes: 2 additions & 2 deletions python/mlc_llm/model/deepseek/deepseek_model.py
@@ -307,10 +307,10 @@ def _set(layer, hint):
def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
out = self.input_layernorm(hidden_states)
out = self.self_attn(out, paged_kv_cache, layer_id)
-hidden_states = self._apply_residual(hidden_states, residual=out)
+hidden_states = self._apply_residual(out, residual=hidden_states)
out = self.post_attention_layernorm(hidden_states)
out = self.mlp(out) # type: ignore[operator]
-hidden_states = self._apply_residual(hidden_states, residual=out)
+hidden_states = self._apply_residual(out, residual=hidden_states)
return hidden_states

def _apply_residual(self, out, residual):
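
A note on the two swapped calls above: when tensor_parallel_shards > 1, _apply_residual allreduces its first argument before adding the residual (see the same helper in the DeepSeek-v2 file below). The per-worker attention/MLP output is a partial sum that needs the allreduce, while the residual stream is already replicated on every worker, so reducing it would over-count it. A toy numeric check, assuming 2 workers:

# Toy check of the residual ordering, assuming 2 workers: layer outputs are
# partial sums, the residual stream is replicated identically on each worker.
partial_out = [3.0, 5.0]   # per-worker partial layer outputs; true layer output = 8.0
residual = 10.0            # same value held by every worker

correct = sum(partial_out) + residual                         # allreduce(out) + residual = 18.0
wrong = sum(residual for _ in partial_out) + partial_out[0]   # allreduce(residual) + out = 23.0
print(correct, wrong)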
109 changes: 93 additions & 16 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -16,6 +16,7 @@
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.nn.expert import MixtralExperts
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold

@@ -79,20 +80,17 @@ def __post_init__(self):
logger.info(
"%s defaults to %d",
bold("prefill_chunk_size"),
-min(self.context_window_size, 8192),
+min(self.context_window_size, 2048),
)
-self.prefill_chunk_size = min(self.context_window_size, 8192)
+self.prefill_chunk_size = min(self.context_window_size, 2048)
elif self.prefill_chunk_size > self.context_window_size:
logger.info(
"Overriding %s from %d to %d",
bold("prefill_chunk_size"),
self.prefill_chunk_size,
-min(self.context_window_size, 8192),
+min(self.context_window_size, 2048),
)
-self.prefill_chunk_size = min(self.context_window_size, 8192)
-
-if self.tensor_parallel_shards != 1:
-raise ValueError("Only support single device at this moment.")
+self.prefill_chunk_size = min(self.context_window_size, 2048)


# pylint: disable=invalid-name,missing-docstring,too-many-locals
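
With the hard single-device check gone, the remaining multi-GPU constraints are the divisibility checks added in the hunks below: attention heads and intermediate sizes must divide evenly by tensor_parallel_shards. A throwaway helper to see which shard counts a given dimension admits; the example sizes are hypothetical, not DeepSeek-V2's actual config values:

# Hypothetical sanity check mirroring the divisibility constraints added below.
def valid_shard_counts(dim: int, max_shards: int = 8) -> list:
    """Shard counts <= max_shards that divide `dim` evenly."""
    return [s for s in range(1, max_shards + 1) if dim % s == 0]

print(valid_shard_counts(16))    # e.g. a head count        -> [1, 2, 4, 8]
print(valid_shard_counts(1408))  # e.g. an intermediate size -> [1, 2, 4, 8]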
@@ -102,9 +100,15 @@ class DeepseekV2MLP(nn.Module):
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
super().__init__()
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
-self.intermediate_size = (
+intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
if intermediate_size % config.tensor_parallel_shards != 0:
raise ValueError(
f"Cannot split MoE intermediate size {intermediate_size} "
f"evenly to {config.tensor_parallel_shards} GPUs."
)
self.intermediate_size = intermediate_size // config.tensor_parallel_shards

self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
@@ -173,7 +177,12 @@ def __init__(self, config: DeepseekV2Config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
-self.num_heads = config.num_attention_heads
+if config.num_attention_heads % config.tensor_parallel_shards != 0:
+raise ValueError(
+f"Cannot split {config.num_attention_heads} attention heads "
+f"evenly to {config.tensor_parallel_shards} GPUs."
+)
+self.num_heads = config.num_attention_heads // config.tensor_parallel_shards

self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
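
The head split determines how the MLA projections are sharded in _set_tp further down: the q and kv up-projections emit one block per head along their output dimension, so they are sharded on dim=0, while o_proj consumes the per-head value outputs along its input dimension, so it is sharded on dim=1. A rough per-shard sizing sketch; the head count and head dims below follow the reference DeepSeek-V2 layout but are placeholders, not values checked against this model definition:

# Rough per-shard sizing for the MLA projections (placeholder numbers).
num_attention_heads = 16
tensor_parallel_shards = 4
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

heads_per_shard = num_attention_heads // tensor_parallel_shards             # 4
q_rows_per_shard = heads_per_shard * (qk_nope_head_dim + qk_rope_head_dim)  # 768
kv_rows_per_shard = heads_per_shard * (qk_nope_head_dim + v_head_dim)       # 1024
o_cols_per_shard = heads_per_shard * v_head_dim                             # 512
print(heads_per_shard, q_rows_per_shard, kv_rows_per_shard, o_cols_per_shard)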
@@ -320,7 +329,12 @@ def __init__(self, config: DeepseekV2Config):

self.gate = nn.Linear(config.hidden_size, self.num_routed_experts, bias=False)
self.norm_topk_prob = config.norm_topk_prob
-self.moe_intermediate_size = config.moe_intermediate_size
+if config.moe_intermediate_size % config.tensor_parallel_shards != 0:
+raise ValueError(
+f"Cannot split MoE intermediate size {config.moe_intermediate_size} "
+f"evenly to {config.tensor_parallel_shards} GPUs."
+)
+self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards

self.moe_gate_up_proj = MixtralExperts(
self.num_routed_experts,
@@ -333,8 +347,9 @@ def __init__(self, config: DeepseekV2Config):
out_features=config.hidden_size,
)

-intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-self.shared_experts = DeepseekV2MLP(config, intermediate_size=intermediate_size)
+self.shared_experts = DeepseekV2MLP(
+config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+)

def forward(self, x: Tensor):
def _expert_forward(x: Tensor, indptr: Tensor):
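
The routed-expert weights live in MixtralExperts-style 3D tensors of shape [num_experts, out_features, in_features], which is why the shard hints below use dim=1 for moe_gate_up_proj (per-expert output rows) and dim=2 for moe_down_proj (per-expert input columns), rather than dims 0 and 1 as for ordinary 2D Linears. A small NumPy shape sketch, with made-up sizes, of what each worker would keep, assuming the segs=[mi, mi] hint splits the gate and up halves separately:

# Shape sketch for sharding 3D expert weights (illustrative sizes only).
import numpy as np

num_experts, hidden, moe_inter, shards = 4, 8, 6, 2
moe_gate_up = np.zeros((num_experts, 2 * moe_inter, hidden))  # sharded on dim=1
moe_down = np.zeros((num_experts, hidden, moe_inter))         # sharded on dim=2

gate, up = np.split(moe_gate_up, 2, axis=1)  # undo the gate/up fusion first
worker0_gate_up = np.concatenate(
    [np.split(gate, shards, axis=1)[0], np.split(up, shards, axis=1)[0]], axis=1
)
worker0_down = np.split(moe_down, shards, axis=2)[0]
print(worker0_gate_up.shape)  # (4, 6, 8): 2 * (moe_inter // shards) rows per expert
print(worker0_down.shape)     # (4, 8, 3): moe_inter // shards columns per expert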
@@ -404,15 +419,72 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int):
config.hidden_size, -1, config.rms_norm_eps, bias=False
)

def _set_tp():
def _set(layer, hint):
layer.attrs["shard_strategy"] = hint

if self.self_attn.q_lora_rank is None:
_set(
self.self_attn.q_proj.weight,
tp.ShardSingleDim("_shard_q_weight", dim=0),
)
else:
_set(
self.self_attn.q_b_proj.weight,
tp.ShardSingleDim("_shard_q_b_weight", dim=0),
)

_set(
self.self_attn.kv_b_proj.weight,
tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
)
_set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))

if isinstance(self.mlp, DeepseekV2MoE):
si = self.mlp.shared_experts.intermediate_size
mi = self.mlp.moe_intermediate_size
_set(
self.mlp.shared_experts.gate_up_proj.weight,
tp.ShardSingleDim("_shard_shared_experts_gate_up", segs=[si, si], dim=0),
)
_set(
self.mlp.shared_experts.down_proj.weight,
tp.ShardSingleDim("_shard_shared_experts_down", dim=1),
)
_set(
self.mlp.moe_gate_up_proj.weight,
tp.ShardSingleDim("_shard_moe_gate_up", segs=[mi, mi], dim=1),
)
_set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))
else:
assert isinstance(self.mlp, DeepseekV2MLP)
si = self.mlp.intermediate_size
_set(
self.mlp.gate_up_proj.weight,
tp.ShardSingleDim("_shard_gate_up", segs=[si, si], dim=0),
)
_set(
self.mlp.down_proj.weight,
tp.ShardSingleDim("_shard_down", dim=1),
)

self.tensor_parallel_shards = config.tensor_parallel_shards
_set_tp()
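
The shard_strategy attribute set by _set above is a hint; roughly speaking, at weight conversion and loading time worker i ends up with the i-th slice of the named dimension. An illustration of that idea, not the actual mlc_llm.support.tensor_parallel implementation:

# Illustration only: what a single-dim shard hint amounts to for worker i.
import numpy as np

def shard_single_dim(weight, dim, worker, shards):
    return np.split(weight, shards, axis=dim)[worker]

o_proj_weight = np.zeros((4096, 8192))  # hypothetical [hidden_size, num_heads * v_head_dim]
print(shard_single_dim(o_proj_weight, dim=1, worker=0, shards=4).shape)  # (4096, 2048)

Because o_proj is split along its input dimension, each worker's o_proj output is a partial sum, which is exactly what the allreduce in _apply_residual below combines.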

def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
out = self.input_layernorm(hidden_states)
out = self.self_attn(out, paged_kv_cache, layer_id)
-hidden_states = hidden_states + out
+hidden_states = self._apply_residual(out, residual=hidden_states)
out = self.post_attention_layernorm(hidden_states)
out = self.mlp(out) # type: ignore[operator]
-hidden_states = hidden_states + out
+hidden_states = self._apply_residual(out, residual=hidden_states)
return hidden_states

def _apply_residual(self, out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out, "sum") + residual
return out + residual


class DeepseekV2Model(nn.Module):
def __init__(self, config: DeepseekV2Config):
@@ -446,6 +518,7 @@ def __init__(self, config: DeepseekV2Config):
self.rms_norm_eps = config.rms_norm_eps
self.rope_theta = config.rope_theta
self.vocab_size = config.vocab_size
self.tensor_parallel_shards = config.tensor_parallel_shards

def to(self, dtype: Optional[str] = None):
super().to(dtype=dtype)
@@ -469,6 +542,8 @@ def batch_forward(
return logits

def embed(self, input_ids: Tensor):
if self.tensor_parallel_shards > 1:
input_ids = op.ccl_broadcast_from_worker0(input_ids)
return self.model.embed_tokens(input_ids)
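
Only worker 0 receives the request's token ids (and, below, the logit positions) from the engine, so they are broadcast before every worker runs the replicated embedding lookup and the positional gather. A toy stand-in for op.ccl_broadcast_from_worker0, just to show the data movement being assumed:

# Toy stand-in for the broadcast step under tensor parallelism (not MLC code).
def broadcast_from_worker0(per_worker_values):
    return [per_worker_values[0] for _ in per_worker_values]

worker_inputs = [[101, 2054, 2003], None, None, None]  # only worker 0 holds the token ids
print(broadcast_from_worker0(worker_inputs))           # every worker now holds the same ids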

def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
@@ -497,6 +572,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
def batch_prefill(
self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
):
if self.tensor_parallel_shards > 1:
logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
return logits, paged_kv_cache

@@ -523,8 +600,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
page_size=page_size,
support_sliding_window=support_sliding_window,
num_hidden_layers=self.num_hidden_layers,
-num_attention_heads=self.num_attention_heads,
-num_key_value_heads=self.num_key_value_heads,
+num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
+num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
head_dim=256,
rope_mode=RopeMode.NONE,
rope_scale=1,
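
Since each worker only computes attention with its local heads, it also only needs to cache K/V for those heads, which is why the paged KV cache above is created with per-shard head counts. A worked example with hypothetical sizes; DeepSeek-V2's real head count is not assumed here:

# Hypothetical per-worker KV-cache sizing after the head split.
num_attention_heads = 16     # placeholder
num_key_value_heads = 16     # placeholder
tensor_parallel_shards = 4
head_dim = 256               # matches the value passed to create_paged_kv_cache

heads_per_worker = num_attention_heads // tensor_parallel_shards      # 4
kv_heads_per_worker = num_key_value_heads // tensor_parallel_shards   # 4
elems_per_token_per_layer = 2 * kv_heads_per_worker * head_dim        # K and V: 2048 per worker
print(heads_per_worker, kv_heads_per_worker, elems_per_token_per_layer)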
