diff --git a/python/mlc_llm/model/deepseek/deepseek_model.py b/python/mlc_llm/model/deepseek/deepseek_model.py index 77883ca124..cdab6e1935 100644 --- a/python/mlc_llm/model/deepseek/deepseek_model.py +++ b/python/mlc_llm/model/deepseek/deepseek_model.py @@ -307,10 +307,10 @@ def _set(layer, hint): def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.input_layernorm(hidden_states) out = self.self_attn(out, paged_kv_cache, layer_id) - hidden_states = self._apply_residual(hidden_states, residual=out) + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.post_attention_layernorm(hidden_states) out = self.mlp(out) # type: ignore[operator] - hidden_states = self._apply_residual(hidden_states, residual=out) + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states def _apply_residual(self, out, residual): diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py index 82dbfcae72..c2cecc3621 100644 --- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py +++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py @@ -16,6 +16,7 @@ from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.nn.expert import MixtralExperts from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -79,20 +80,17 @@ def __post_init__(self): logger.info( "%s defaults to %d", bold("prefill_chunk_size"), - min(self.context_window_size, 8192), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = min(self.context_window_size, 8192) + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - min(self.context_window_size, 8192), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = min(self.context_window_size, 8192) - - if self.tensor_parallel_shards != 1: - raise ValueError("Only support single device at this moment.") + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring,too-many-locals @@ -102,9 +100,15 @@ class DeepseekV2MLP(nn.Module): def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None): super().__init__() self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = ( + intermediate_size = ( config.intermediate_size if intermediate_size is None else intermediate_size ) + if intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MoE intermediate size {intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + self.intermediate_size = intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) @@ -173,7 +177,12 @@ def __init__(self, config: DeepseekV2Config): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank @@ -320,7 +329,12 @@ def __init__(self, config: DeepseekV2Config): self.gate = nn.Linear(config.hidden_size, self.num_routed_experts, bias=False) self.norm_topk_prob = config.norm_topk_prob - self.moe_intermediate_size = config.moe_intermediate_size + if config.moe_intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MoE intermediate size {config.moe_intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards self.moe_gate_up_proj = MixtralExperts( self.num_routed_experts, @@ -333,8 +347,9 @@ def __init__(self, config: DeepseekV2Config): out_features=config.hidden_size, ) - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP(config, intermediate_size=intermediate_size) + self.shared_experts = DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) def forward(self, x: Tensor): def _expert_forward(x: Tensor, indptr: Tensor): @@ -404,15 +419,72 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int): config.hidden_size, -1, config.rms_norm_eps, bias=False ) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + if self.self_attn.q_lora_rank is None: + _set( + self.self_attn.q_proj.weight, + tp.ShardSingleDim("_shard_q_weight", dim=0), + ) + else: + _set( + self.self_attn.q_b_proj.weight, + tp.ShardSingleDim("_shard_q_b_weight", dim=0), + ) + + _set( + self.self_attn.kv_b_proj.weight, + tp.ShardSingleDim("_shard_kv_b_weight", dim=0), + ) + _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1)) + + if isinstance(self.mlp, DeepseekV2MoE): + si = self.mlp.shared_experts.intermediate_size + mi = self.mlp.moe_intermediate_size + _set( + self.mlp.shared_experts.gate_up_proj.weight, + tp.ShardSingleDim("_shard_shared_experts_gate_up", segs=[si, si], dim=0), + ) + _set( + self.mlp.shared_experts.down_proj.weight, + tp.ShardSingleDim("_shard_shared_experts_down", dim=1), + ) + _set( + self.mlp.moe_gate_up_proj.weight, + tp.ShardSingleDim("_shard_moe_gate_up", segs=[mi, mi], dim=1), + ) + _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2)) + else: + assert isinstance(self.mlp, DeepseekV2MLP) + si = self.mlp.intermediate_size + _set( + self.mlp.gate_up_proj.weight, + tp.ShardSingleDim("_shard_gate_up", segs=[si, si], dim=0), + ) + _set( + self.mlp.down_proj.weight, + tp.ShardSingleDim("_shard_down", dim=1), + ) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.input_layernorm(hidden_states) out = self.self_attn(out, paged_kv_cache, layer_id) - hidden_states = hidden_states + out + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.post_attention_layernorm(hidden_states) out = self.mlp(out) # type: ignore[operator] - hidden_states = hidden_states + out + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class DeepseekV2Model(nn.Module): def __init__(self, config: DeepseekV2Config): @@ -446,6 +518,7 @@ def __init__(self, config: DeepseekV2Config): self.rms_norm_eps = config.rms_norm_eps self.rope_theta = config.rope_theta self.vocab_size = config.vocab_size + self.tensor_parallel_shards = config.tensor_parallel_shards def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) @@ -469,6 +542,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -497,6 +572,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache @@ -523,8 +600,8 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments page_size=page_size, support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, head_dim=256, rope_mode=RopeMode.NONE, rope_scale=1,