Commit
add support llama tp
xrsrke committed May 15, 2024
1 parent 3c2e92a commit fc37517
Showing 2 changed files with 82 additions and 15 deletions.
70 changes: 57 additions & 13 deletions src/nanotron/mod/mod.py
@@ -1,31 +1,75 @@
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torchtyping import TensorType
import torch.nn.functional as F

from nanotron.parallel.pipeline_parallel.block import TensorPointer


class MixtureOfDepth(nn.Module):
def __init__(self, capacity: int, d_model: int, block: nn.Module):
super().__init__()
self.router = Router(capacity, d_model)
self.block = block

def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "seq_len", "d_model"]:
selected_idxs = self.router(inputs)
assert selected_idxs.shape == (inputs.size(0), self.router.capacity)
selected_inputs = inputs[torch.arange(inputs.size(0)).unsqueeze(1), selected_idxs]

outputs_of_selected_inputs = self.block(selected_inputs)

# def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "seq_len", "d_model"]:
def forward(
self,
hidden_states: Union[TensorType["batch_size", "seq_len", "d_model"], TensorPointer],
sequence_mask: Union[TensorType["batch_size", "seq_len"], TensorPointer],
) -> Tuple[
Union[TensorType["batch_size", "seq_len", "d_model"], TensorPointer],
Union[TensorType["batch_size", "seq_len"], TensorPointer],
]:
hidden_states = rearrange(hidden_states, "seq_len batch_size d_model -> batch_size seq_len d_model")
selected_idxs = self.router(hidden_states)
assert selected_idxs.shape == (hidden_states.size(0), self.router.capacity)
selected_hidden_states = hidden_states[torch.arange(hidden_states.size(0)).unsqueeze(1), selected_idxs]
selected_sequence_mask = sequence_mask[torch.arange(sequence_mask.size(0)).unsqueeze(1), selected_idxs]

selected_hidden_states = rearrange(
selected_hidden_states, "batch_size seq_len d_model -> seq_len batch_size d_model"
)
outputs_of_selected_inputs = self.block(selected_hidden_states, selected_sequence_mask)
# NOTE: now keep the representation of the selected inputs and replace the original inputs with the new ones
inputs[torch.arange(inputs.size(0)).unsqueeze(1), selected_idxs] = outputs_of_selected_inputs
return inputs

hidden_states[torch.arange(hidden_states.size(0)).unsqueeze(1), selected_idxs] = rearrange(
outputs_of_selected_inputs["hidden_states"], "seq_len batch_size d_model -> batch_size seq_len d_model"
)
hidden_states = rearrange(hidden_states, "batch_size seq_len d_model -> seq_len batch_size d_model")
return {"hidden_states": hidden_states, "sequence_mask": sequence_mask}


class Router(nn.Module):
def __init__(self, capacity: int, d_model: int):
def __init__(
self,
capacity: int,
d_model: int,
# tp_pg: dist.ProcessGroup,
# parallel_config: Optional[ParallelismArgs]
):
super().__init__()
self.capacity = capacity
self.gate = nn.Linear(d_model, 1)


# TODO(xrsrke): deduplicate this
# tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
# tp_linear_async_communication = (
# parallel_config.tp_linear_async_communication if parallel_config is not None else False
# )

# self.gate = TensorParallelRowLinear(
# d_model,
# 1,
# pg=tp_pg,
# mode=TensorParallelLinearMode.REDUCE_SCATTER,
# bias=False,
# async_communication=True,
# # contiguous_chunks=gate_up_contiguous_chunks,
# )

def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "seq_len"]:
probs = F.softmax(self.gate(inputs), dim=1).view(-1, inputs.size(1))
_, top_k_indices = torch.topk(probs, self.capacity)
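For reference, the change above can be read as: a learned gate scores every token, the top-capacity tokens per sequence are gathered, only those tokens are run through the wrapped block, and the block's outputs are scattered back into place so the remaining tokens skip the layer entirely. The commented-out TensorParallelRowLinear in Router.__init__ marks where the tensor-parallel gate implied by the commit title would slot in. Below is a minimal, self-contained sketch of that routing pattern; ToyRouter, ToyMixtureOfDepth and the stand-in nn.Linear block are illustrative names only, and the sketch deliberately omits the sequence mask, the seq-first/batch-first rearranges, and any parallelism.

# Minimal sketch of the Mixture-of-Depths routing above (assumed simplification:
# plain nn.Linear gate, batch-first tensors, no sequence_mask, no parallelism).
import torch
import torch.nn.functional as F
from torch import nn

class ToyRouter(nn.Module):
    def __init__(self, capacity: int, d_model: int):
        super().__init__()
        self.capacity = capacity
        self.gate = nn.Linear(d_model, 1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs: [batch_size, seq_len, d_model] -> one score per token
        probs = F.softmax(self.gate(inputs), dim=1).view(-1, inputs.size(1))
        _, top_k_indices = torch.topk(probs, self.capacity)  # [batch_size, capacity]
        return top_k_indices

class ToyMixtureOfDepth(nn.Module):
    def __init__(self, capacity: int, d_model: int, block: nn.Module):
        super().__init__()
        self.router = ToyRouter(capacity, d_model)
        self.block = block

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        idxs = self.router(hidden_states)                           # [batch_size, capacity]
        batch_idx = torch.arange(hidden_states.size(0)).unsqueeze(1)
        selected = hidden_states[batch_idx, idxs]                   # gather the routed tokens
        processed = self.block(selected)                            # run the block on them only
        out = hidden_states.clone()
        out[batch_idx, idxs] = processed                            # scatter results back in place
        return out                                                  # unselected tokens pass through unchanged

if __name__ == "__main__":
    torch.manual_seed(0)
    block = nn.Linear(16, 16)  # stand-in for a decoder layer
    mod = ToyMixtureOfDepth(capacity=4, d_model=16, block=block)
    x = torch.randn(2, 10, 16)
    print(mod(x).shape)  # torch.Size([2, 10, 16])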
27 changes: 25 additions & 2 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Union

import torch
from torch import nn
@@ -25,6 +25,7 @@
from nanotron.config.models_config import RandomInit, SpectralMupInit
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.mod.mod import MixtureOfDepth
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
@@ -45,6 +46,15 @@

logger = logging.get_logger(__name__)

CAPACITY = 50
D_MODEL = 16


def build_mod_block(*args, **kwargs):
block = LlamaDecoderLayer(*args, **kwargs)
mod = MixtureOfDepth(CAPACITY, D_MODEL, block)
return mod
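
The CAPACITY and D_MODEL constants above are placeholders (the commented-out note further down in this file reads D_MODEL = config.hidden_size). A hypothetical, config-driven variant of the builder, purely a sketch and not part of this commit, could read the hidden size from the config kwarg that PipelineBlock already forwards via module_kwargs:

# Hypothetical variant (not in the commit): take d_model from the config
# passed through module_kwargs instead of a module-level constant.
def build_mod_block(*args, **kwargs):
    block = LlamaDecoderLayer(*args, **kwargs)
    d_model = kwargs["config"].hidden_size  # assumes the builder is always called with config as a keyword
    return MixtureOfDepth(CAPACITY, d_model, block)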


class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -704,11 +714,20 @@ def __init__(
module_output_keys={"input_embeds"},
)

# def build_mod_block(module_kwargs, block):
# block = self.module_builder(**self.module_kwargs)

# NOTE: how make MixtureOfDepth block wrap around these blocks?

# CAPACITY = 50
# D_MODEL = config.hidden_size

self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=LlamaDecoderLayer,
# module_builder=LlamaDecoderLayer,
module_builder=build_mod_block,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
@@ -755,6 +774,8 @@ def __init__(
module_output_keys={"output"},
)

# self.mod_blocks = []

def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
@@ -795,6 +816,8 @@ def get_block_compute_costs(self):
# CausalSelfAttention (qkv proj + attn out) + MLP
LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
build_mod_block: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
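The new build_mod_block entry reuses the LlamaDecoderLayer weight, which is consistent with each MoD block wrapping exactly one decoder layer. For intuition, here is that per-layer weight evaluated numerically, assuming d_qkv = hidden_size // num_attention_heads and d_ff = intermediate_size as defined earlier in this method, with Llama-7B-like sizes chosen purely for illustration:

# Worked example of the per-layer cost above (assumed 7B-like sizes:
# hidden_size=4096, 32 attention heads, intermediate_size=11008).
hidden_size, num_heads, intermediate_size = 4096, 32, 11008
d_qkv = hidden_size // num_heads  # 128
d_ff = intermediate_size          # 11008
cost = 4 * num_heads * d_qkv * hidden_size + 3 * d_ff * hidden_size
print(f"{cost:,}")  # 202,375,168 -- the same value is keyed under LlamaDecoderLayer and build_mod_block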
