Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Mixture of Depths #171

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/nanotron/mod/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from nanotron.mod.mod import MixtureOfDepth, Router
146 changes: 146 additions & 0 deletions src/nanotron/mod/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from typing import Dict, Optional, Union, List

import torch
from torch import nn

import torch.distributed as dist
from nanotron.config import LlamaConfig, ParallelismArgs
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelLinearMode,
)
from nanotron.models.llama import LlamaModel, Embedding, LlamaDecoderLayer, CausalSelfAttention, MLP
from nanotron.mod.mod import MixtureOfDepth, Router


# class LlamaDecoderLayer(nn.Module):
# def __init__(
# self,
# config: LlamaConfig,
# parallel_config: Optional[ParallelismArgs],
# tp_pg: dist.ProcessGroup,
# layer_idx: int,
# ):
# super().__init__()
# self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.attn = CausalSelfAttention(
# config=config,
# parallel_config=parallel_config,
# tp_pg=tp_pg,
# layer_idx=layer_idx,
# )

# self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
# self.router = Router(seq_len=1024, top_k=10)

# def forward(
# self,
# hidden_states: Union[torch.Tensor, TensorPointer],
# sequence_mask: Union[torch.Tensor, TensorPointer],
# ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# residual = hidden_states
# hidden_states = self.input_layernorm(hidden_states)

# output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
# hidden_states = output["hidden_states"]
# hidden_states = hidden_states + residual

# residual = hidden_states
# hidden_states = self.post_attention_layernorm(hidden_states)
# hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
# hidden_states = hidden_states + residual

# return {
# "hidden_states": hidden_states,
# "sequence_mask": output["sequence_mask"],
# }


class MoDLlamaModel(nn.Module, LlamaModel):
"""Build pipeline graph"""

def __init__(
self,
config: LlamaConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()

# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)

self.token_position_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)

self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=LlamaDecoderLayer,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"layer_idx": layer_idx,
},
module_input_keys={"hidden_states", "sequence_mask"},
module_output_keys={"hidden_states", "sequence_mask"},
)
for layer_idx in range(config.num_hidden_layers)
]
)

self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
)

self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)

self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
76 changes: 76 additions & 0 deletions src/nanotron/mod/mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torchtyping import TensorType

from nanotron.parallel.pipeline_parallel.block import TensorPointer


class MixtureOfDepth(nn.Module):
def __init__(self, capacity: int, d_model: int, block: nn.Module):
super().__init__()
self.router = Router(capacity, d_model)
self.block = block

# def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "seq_len", "d_model"]:
def forward(
self,
hidden_states: Union[TensorType["batch_size", "seq_len", "d_model"], TensorPointer],
sequence_mask: Union[TensorType["batch_size", "seq_len"], TensorPointer],
) -> Tuple[
Union[TensorType["batch_size", "seq_len", "d_model"], TensorPointer],
Union[TensorType["batch_size", "seq_len"], TensorPointer],
]:
hidden_states = rearrange(hidden_states, "seq_len batch_size d_model -> batch_size seq_len d_model")
selected_idxs = self.router(hidden_states)
assert selected_idxs.shape == (hidden_states.size(0), self.router.capacity)
selected_hidden_states = hidden_states[torch.arange(hidden_states.size(0)).unsqueeze(1), selected_idxs]
selected_sequence_mask = sequence_mask[torch.arange(sequence_mask.size(0)).unsqueeze(1), selected_idxs]

selected_hidden_states = rearrange(
selected_hidden_states, "batch_size seq_len d_model -> seq_len batch_size d_model"
)
outputs_of_selected_inputs = self.block(selected_hidden_states, selected_sequence_mask)
# NOTE: now keep the representation of the selected inputs and replace the original inputs with the new ones
hidden_states[torch.arange(hidden_states.size(0)).unsqueeze(1), selected_idxs] = rearrange(
outputs_of_selected_inputs["hidden_states"], "seq_len batch_size d_model -> batch_size seq_len d_model"
)
hidden_states = rearrange(hidden_states, "batch_size seq_len d_model -> seq_len batch_size d_model")
return {"hidden_states": hidden_states, "sequence_mask": sequence_mask}


class Router(nn.Module):
def __init__(
self,
capacity: int,
d_model: int,
# tp_pg: dist.ProcessGroup,
# parallel_config: Optional[ParallelismArgs]
):
super().__init__()
self.capacity = capacity
self.gate = nn.Linear(d_model, 1)

# TODO(xrsrke): deduplicate this
# tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
# tp_linear_async_communication = (
# parallel_config.tp_linear_async_communication if parallel_config is not None else False
# )

# self.gate = TensorParallelRowLinear(
# d_model,
# 1,
# pg=tp_pg,
# mode=TensorParallelLinearMode.REDUCE_SCATTER,
# bias=False,
# async_communication=True,
# # contiguous_chunks=gate_up_contiguous_chunks,
# )

def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "seq_len"]:
probs = F.softmax(self.gate(inputs), dim=1).view(-1, inputs.size(1))
_, top_k_indices = torch.topk(probs, self.capacity)
return top_k_indices
27 changes: 25 additions & 2 deletions src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Union

import torch
from torch import nn
Expand All @@ -25,6 +25,7 @@
from nanotron.config.models_config import RandomInit, SpectralMupInit
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.mod.mod import MixtureOfDepth
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
Expand All @@ -45,6 +46,15 @@

logger = logging.get_logger(__name__)

CAPACITY = 50
D_MODEL = 16


def build_mod_block(*args, **kwargs):
block = LlamaDecoderLayer(*args, **kwargs)
mod = MixtureOfDepth(CAPACITY, D_MODEL, block)
return mod


class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
Expand Down Expand Up @@ -704,11 +714,20 @@ def __init__(
module_output_keys={"input_embeds"},
)

# def build_mod_block(module_kwargs, block):
# block = self.module_builder(**self.module_kwargs)

# NOTE: how make MixtureOfDepth block wrap around these blocks?

# CAPACITY = 50
# D_MODEL = config.hidden_size

self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=LlamaDecoderLayer,
# module_builder=LlamaDecoderLayer,
module_builder=build_mod_block,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
Expand Down Expand Up @@ -755,6 +774,8 @@ def __init__(
module_output_keys={"output"},
)

# self.mod_blocks = []

def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
Expand Down Expand Up @@ -795,6 +816,8 @@ def get_block_compute_costs(self):
# CausalSelfAttention (qkv proj + attn out) + MLP
LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
build_mod_block: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
Expand Down
36 changes: 36 additions & 0 deletions tests/test_mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from torch import nn
import pytest

from nanotron.mod import MixtureOfDepth, Router


@pytest.mark.parametrize("seq_len, top_k", [(1, 1), (10, 5), (10, 10)])
def test_mod(seq_len, top_k):
BATCH_SIZE = 15
D_MODEL = 1024

linear = nn.Linear(D_MODEL, D_MODEL)
block = MixtureOfDepth(top_k, D_MODEL, linear)

inputs = torch.randn(BATCH_SIZE, seq_len, D_MODEL)
ref_inputs = inputs.clone()
outputs = block(inputs)

expected_num_tokens_not_changed = (seq_len - top_k) * BATCH_SIZE
num_tokens_not_changed = torch.eq(outputs.view(-1, D_MODEL), ref_inputs.view(-1, D_MODEL)).all(dim=1).sum().item()

assert outputs.shape == linear(ref_inputs).shape
assert num_tokens_not_changed == expected_num_tokens_not_changed, f"num_tokens_not_changed: {num_tokens_not_changed}, expected: {expected_num_tokens_not_changed}"


@pytest.mark.parametrize("capacity, d_model", [(1, 64), (10, 64)])
def test_router(capacity, d_model):
BATCH_SIZE, SEQ_LEN = 5, 10
inputs = torch.randn(BATCH_SIZE, SEQ_LEN, d_model)

router = Router(capacity, d_model)
selected_idxs = router(inputs)

assert selected_idxs.shape == (BATCH_SIZE, capacity)
assert selected_idxs.dtype == torch.int64
Loading