diff --git a/python/mlc_llm/model/gpt_j/gpt_j_model.py b/python/mlc_llm/model/gpt_j/gpt_j_model.py
index b72c938f8a..b5f199319e 100644
--- a/python/mlc_llm/model/gpt_j/gpt_j_model.py
+++ b/python/mlc_llm/model/gpt_j/gpt_j_model.py
@@ -13,6 +13,7 @@ from mlc_llm import op as op_ext
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold
@@ -57,6 +58,9 @@ def __post_init__(self):
                     "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                     "provided in `config.json`."
                 )
+        if self.head_dim == 0:
+            self.head_dim = self.n_embd // self.n_head
+        assert self.head_dim * self.n_head == self.n_embd
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %d",
@@ -72,7 +76,6 @@ def __post_init__(self):
                 min(self.context_window_size, 8192),
             )
             self.prefill_chunk_size = min(self.context_window_size, 8192)
-        assert self.tensor_parallel_shards == 1, "GPTJ currently does not support sharding."
 
 
 # pylint: disable=invalid-name,missing-docstring
@@ -82,7 +85,7 @@ class GPTJAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: GPTJConfig):
         self.embed_dim = config.n_embd
         self.num_heads = config.n_head // config.tensor_parallel_shards
-        self.head_dim = self.embed_dim // self.num_heads
+        self.head_dim = config.head_dim
         self.max_position_embeddings = config.context_window_size
         self.rope_theta = 10000
         self.rotary_dim = config.rotary_dim
@@ -140,14 +143,41 @@ def __init__(self, config: GPTJConfig):
         self.attn = GPTJAttention(config)
         self.mlp = GPTJMLP(config)
 
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.attn.num_heads * hd
+            k = self.attn.num_heads * hd
+            v = self.attn.num_heads * hd
+            _set(
+                self.attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.attn.out_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.fc_in.weight,
+                tp.ShardSingleDim("_shard_c_fc_weight", dim=0),
+            )
+            _set(self.mlp.fc_out.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(hidden_states, paged_kv_cache, layer_id)
         feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
+        hidden_states = self._apply_residual(attn_output + feed_forward_hidden_states, residual)
         return hidden_states
 
+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+
 
 class GPTJModel(nn.Module):
     def __init__(self, config: GPTJConfig):
diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py
index 06f00a0b6a..e4ecb998db 100644
--- a/python/mlc_llm/model/olmo/olmo_model.py
+++ b/python/mlc_llm/model/olmo/olmo_model.py
@@ -102,7 +102,7 @@ def __post_init__(self):  # pylint: disable=too-many-branches
             raise ValueError(f"'clip_qkv'({self.clip_qkv}) should be non-negative")
 
 
-class OLMoEebedding(nn.Embedding):
+class OLMoEmbedding(nn.Embedding):
     """The embedding module that can be shared with the final lm_head.
     From Qwen2Embedding."""
 
     def lm_head_forward(self, x: nn.Tensor):
@@ -248,7 +248,7 @@ def forward(  # pylint: disable=missing-function-docstring
 class OLMoModel(nn.Module):  # pylint: disable=missing-class-docstring
     def __init__(self, config: OLMoConfig):
         assert config.hidden_size % config.num_attention_heads == 0
-        self.embed_tokens = OLMoEebedding(config.vocab_size, config.hidden_size)
+        self.embed_tokens = OLMoEmbedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
             [OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
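
Not part of the patch: a minimal, self-contained sketch of how the `segs=[q, k, v]` argument on the dim-0 QKV shard strategy is assumed to behave, so the row split can be sanity-checked without TVM. The helper `shard_dim0_with_segs` and the toy shapes are made up for illustration; the actual splitting is performed by `tp.ShardSingleDim` inside mlc_llm.

```python
# Hypothetical standalone sketch (not mlc_llm code): assumed semantics of a
# dim-0 shard with segments [q, k, v] applied to a fused QKV weight of shape
# [q + k + v, in_features]. Each segment is divided evenly across shards, so
# every shard keeps a contiguous slice of its own Q, K and V rows.
from typing import List


def shard_dim0_with_segs(
    rows: List[List[float]], segs: List[int], num_shards: int
) -> List[List[List[float]]]:
    """Split `rows` (a matrix stored as a list of row vectors) into
    `num_shards` pieces, dividing each segment in `segs` evenly across shards."""
    assert sum(segs) == len(rows), "segments must cover every row"
    assert all(s % num_shards == 0 for s in segs), "each segment must divide evenly"
    shards: List[List[List[float]]] = [[] for _ in range(num_shards)]
    offset = 0
    for seg in segs:
        per_shard = seg // num_shards
        for shard_id in range(num_shards):
            start = offset + shard_id * per_shard
            shards[shard_id].extend(rows[start : start + per_shard])
        offset += seg
    return shards


if __name__ == "__main__":
    # Toy fused QKV weight: num_heads=1, head_dim=4, so q = k = v = 4 rows.
    q = k = v = 4
    weight = [[float(i)] for i in range(q + k + v)]  # 12 rows, 1 column each
    shard0, shard1 = shard_dim0_with_segs(weight, [q, k, v], num_shards=2)
    # Each shard ends up with 2 query rows, 2 key rows and 2 value rows.
    print([row[0] for row in shard0])  # [0.0, 1.0, 4.0, 5.0, 8.0, 9.0]
    print([row[0] for row in shard1])  # [2.0, 3.0, 6.0, 7.0, 10.0, 11.0]
```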