[SLM] GPTJ Multi-GPU support (mlc-ai#3070)
This PR adds tensor parallelism (TP) support to the GPTJ model and fixes a minor typo in the OLMo model. A brief illustrative sketch of the sharding scheme follows the changed-files summary below.
tlopex authored Dec 19, 2024
1 parent 9a33772 commit 1825fed
Showing 2 changed files with 35 additions and 5 deletions.
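
The GPTJ change marks the attention and MLP projection weights with shard strategies so that each GPU keeps only its own slice of the heads. As a rough illustration of what a dim=0 split with segs=[q, k, v] (as in tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]) below) asks the loader to do, here is a minimal NumPy sketch with made-up toy shapes; it is not MLC LLM's actual sharding code.

# Illustrative sketch only: split a fused QKV weight along dim 0, segment by
# segment, so every shard keeps its own heads' Q, K and V rows. Toy shapes;
# the real work is done by mlc_llm.support.tensor_parallel.
import numpy as np

num_shards = 2
n_head, head_dim, n_embd = 8, 4, 32      # toy config with n_embd == n_head * head_dim
q = k = v = n_head * head_dim            # rows of each segment in the fused weight

w_qkv = np.random.randn(q + k + v, n_embd).astype("float32")   # fused QKV weight

def shard_qkv(weight, shard_id):
    """Split each of the Q, K, V segments along dim 0 and keep this shard's rows."""
    w_q, w_k, w_v = np.split(weight, [q, q + k], axis=0)
    pieces = []
    for seg in (w_q, w_k, w_v):
        rows = seg.shape[0] // num_shards
        pieces.append(seg[shard_id * rows : (shard_id + 1) * rows])
    return np.concatenate(pieces, axis=0)

shards = [shard_qkv(w_qkv, i) for i in range(num_shards)]
# Each shard ends up with (q + k + v) / num_shards rows: its own heads' Q, K and V.
assert all(s.shape == ((q + k + v) // num_shards, n_embd) for s in shards)
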
36 changes: 33 additions & 3 deletions python/mlc_llm/model/gpt_j/gpt_j_model.py
@@ -13,6 +13,7 @@
from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold

@@ -57,6 +58,9 @@ def __post_init__(self):
"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
"provided in `config.json`."
)
if self.head_dim == 0:
self.head_dim = self.n_embd // self.n_head
assert self.head_dim * self.n_head == self.n_embd
if self.prefill_chunk_size == 0:
logger.info(
"%s defaults to %d",
@@ -72,7 +76,6 @@ def __post_init__(self):
min(self.context_window_size, 8192),
)
self.prefill_chunk_size = min(self.context_window_size, 8192)
assert self.tensor_parallel_shards == 1, "GPTJ currently does not support sharding."


# pylint: disable=invalid-name,missing-docstring
@@ -82,7 +85,7 @@ class GPTJAttention(nn.Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: GPTJConfig):
self.embed_dim = config.n_embd
self.num_heads = config.n_head // config.tensor_parallel_shards
self.head_dim = self.embed_dim // self.num_heads
self.head_dim = config.head_dim
self.max_position_embeddings = config.context_window_size
self.rope_theta = 10000
self.rotary_dim = config.rotary_dim
@@ -140,14 +143,41 @@ def __init__(self, config: GPTJConfig):
self.attn = GPTJAttention(config)
self.mlp = GPTJMLP(config)

def _set_tp():
def _set(layer, hint):
layer.attrs["shard_strategy"] = hint

hd = config.head_dim
q = self.attn.num_heads * hd
k = self.attn.num_heads * hd
v = self.attn.num_heads * hd
_set(
self.attn.c_attn.weight,
tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
)
_set(self.attn.out_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
_set(
self.mlp.fc_in.weight,
tp.ShardSingleDim("_shard_c_fc_weight", dim=0),
)
_set(self.mlp.fc_out.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1))

self.tensor_parallel_shards = config.tensor_parallel_shards
_set_tp()

def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(hidden_states, paged_kv_cache, layer_id)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_output + feed_forward_hidden_states + residual
hidden_states = self._apply_residual(attn_output + feed_forward_hidden_states, residual)
return hidden_states

def _apply_residual(self, out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out, "sum") + residual
return out + residual


class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig):
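
One note on the _apply_residual helper introduced above: out_proj.weight and fc_out.weight are sharded along their input dimension (dim=1), so with more than one shard each GPU computes only a partial projection from its own heads, and those partials have to be summed across GPUs (the op.ccl_allreduce(out, "sum") call) before the residual is added back. Below is a minimal NumPy sketch of that identity, again with toy shapes rather than the real runtime.

# Illustrative sketch only: summing per-shard partial projections (the job of
# ccl_allreduce) reproduces the unsharded output projection.
import numpy as np

num_shards, n_embd = 2, 32
x = np.random.randn(1, n_embd).astype("float32")            # flattened attention output
w_out = np.random.randn(n_embd, n_embd).astype("float32")   # weight stored as (out, in); y = x @ W.T

full = x @ w_out.T                                          # what a single GPU would compute

cols = n_embd // num_shards
partials = [
    x[:, i * cols : (i + 1) * cols] @ w_out[:, i * cols : (i + 1) * cols].T
    for i in range(num_shards)
]
np.testing.assert_allclose(sum(partials), full, rtol=1e-4, atol=1e-4)
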
4 changes: 2 additions & 2 deletions python/mlc_llm/model/olmo/olmo_model.py
@@ -102,7 +102,7 @@ def __post_init__(self): # pylint: disable=too-many-branches
raise ValueError(f"'clip_qkv'({self.clip_qkv}) should be non-negative")


class OLMoEebedding(nn.Embedding):
class OLMoEmbedding(nn.Embedding):
"""The embedding module that can be shared with the final lm_head. From Qwen2Embedding."""

def lm_head_forward(self, x: nn.Tensor):
@@ -248,7 +248,7 @@ def forward( # pylint: disable=missing-function-docstring
class OLMoModel(nn.Module): # pylint: disable=missing-class-docstring
def __init__(self, config: OLMoConfig):
assert config.hidden_size % config.num_attention_heads == 0
self.embed_tokens = OLMoEebedding(config.vocab_size, config.hidden_size)
self.embed_tokens = OLMoEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
