add outlier protected block (qk-norm, LayerScale, removed pre-norms, non-gated MLPs)
bobby-he committed Oct 1, 2024
1 parent 7b7ead9 commit 1edaa95
Showing 5 changed files with 157 additions and 26 deletions.
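For orientation, the pieces this commit adds (RMSNorm on the per-head queries and keys, LayerScale gains on the two residual branches, identity pre-norms, and a plain non-gated MLP) compose into a decoder block roughly like the sketch below. This is an illustrative, self-contained module with made-up names (`OutlierProtectedBlock`, `SimpleRMSNorm`, the value `resid_gain=0.1`); the committed code in `src/nanotron/models/llama.py` instead uses nanotron's tensor-parallel linears, `TritonRMSNorm`, and the config flags shown in the diff below.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SimpleRMSNorm(nn.Module):
    """Plain RMSNorm stand-in for the repo's TritonRMSNorm."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * x * rms


class OutlierProtectedBlock(nn.Module):
    """No pre-norms, qk-norm, LayerScale on both residual branches, non-gated MLP."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, resid_gain: float = 0.1):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.q_norm = SimpleRMSNorm(self.d_head)   # qk-norm: per-head RMSNorm on queries/keys
        self.k_norm = SimpleRMSNorm(self.d_head)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)    # non-gated MLP
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.GELU()
        # LayerScale: per-channel residual gains, initialised to resid_gain (config.resid_gain)
        self.attn_scale = nn.Parameter(resid_gain * torch.ones(d_model))
        self.mlp_scale = nn.Parameter(resid_gain * torch.ones(d_model))

    def forward(self, x):  # x: [batch, seq, d_model]
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = self.q_norm(q.view(b, s, self.n_heads, self.d_head)).transpose(1, 2)
        k = self.k_norm(k.view(b, s, self.n_heads, self.d_head)).transpose(1, 2)
        v = v.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).reshape(b, s, -1)
        x = x + self.attn_scale * self.o_proj(attn)                         # no input_layernorm
        x = x + self.mlp_scale * self.down_proj(self.act(self.up_proj(x)))  # no post_attention_layernorm
        return x
```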
5 changes: 5 additions & 0 deletions src/nanotron/config/models_config.py
@@ -54,6 +54,11 @@ class LlamaConfig:
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
norm_type: str = "rmsnorm"
qknorm_type: str = "none"
resid_gain: float = 1.
use_final_norm: bool = True
use_gated_mlp: bool = False

def __post_init__(self):
# NOTE: user don't set self._init_method, ModelArgs will set it
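The defaults above keep the standard Llama block (RMSNorm pre-norms, no qk-norm, unit residual gain, gated MLP). A hedged sketch of the overrides that switch on the outlier-protected variant; the field names come from the diff, the specific values (e.g. `resid_gain=0.1`) are illustrative:

```python
# Illustrative LlamaConfig overrides; any norm_type other than "rmsnorm" maps to nn.Identity pre-norms.
outlier_protected_overrides = {
    "norm_type": "none",       # drop input_layernorm / post_attention_layernorm
    "qknorm_type": "rmsnorm",  # RMSNorm applied to queries and keys in attention
    "resid_gain": 0.1,         # != 1 enables LayerScale on both residual branches, initialised to 0.1
    "use_final_norm": True,    # keep the final RMSNorm before the LM head
    "use_gated_mlp": False,    # plain up_proj -> activation -> down_proj MLP
}
```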
112 changes: 94 additions & 18 deletions src/nanotron/models/llama.py
@@ -15,7 +15,7 @@
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List

from collections import defaultdict
import torch
from torch import nn
from torch.utils.checkpoint import CheckpointFunction
@@ -28,7 +28,7 @@
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.nn.layer_norm import TritonRMSNorm, LayerScale
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
@@ -129,6 +129,48 @@ def forward(self, merged_states: torch.Tensor):


class MLP(nn.Module):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
):
super().__init__()

# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)

gate_up_contiguous_chunks = (
config.intermediate_size, # shape of gate_linear
)
self.up_proj = TensorParallelColumnLinear(
config.hidden_size,
config.intermediate_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
self.act = ACT2FN[config.hidden_act]

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.up_proj(hidden_states)
hidden_states = self.down_proj(self.act(merged_states))
return {"hidden_states": hidden_states}

class GatedMLP(nn.Module):
def __init__(
self,
config: LlamaConfig,
@@ -172,7 +214,6 @@ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
hidden_states = self.down_proj(self.split_silu_mul(merged_states))
return {"hidden_states": hidden_states}


class CoreAttention(nn.Module):
def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
super().__init__()
@@ -334,6 +375,12 @@ def __init__(
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
if config.qknorm_type == "rmsnorm":
self.q_norm = TritonRMSNorm(self.d_qk, eps=config.rms_norm_eps)
self.k_norm = TritonRMSNorm(self.d_qk, eps=config.rms_norm_eps)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()

def forward(
self,
@@ -378,6 +425,10 @@ def forward(
.contiguous()
) # [3, batch_size, seq_length, n_local_q_heads, d_qk]

query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)


store = self.get_local_store()
if store is not None: # Inference case
# Double check that we use store only at inference time
@@ -583,19 +634,34 @@ def __init__(
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.norm_type == "rmsnorm":
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.input_layernorm = nn.Identity()
self.post_attention_layernorm = nn.Identity()

self.attn = CausalSelfAttention(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
if config.resid_gain != 1:
self.attn_layer_scale = LayerScale(config.hidden_size, eps=config.resid_gain)
self.mlp_layer_scale = LayerScale(config.hidden_size, eps=config.resid_gain)
else:
self.attn_layer_scale = nn.Identity()
self.mlp_layer_scale = nn.Identity()

self.recompute_layer = parallel_config.recompute_layer
if config.use_gated_mlp:
self.mlp = GatedMLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
else:
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

self.recompute_layer = parallel_config.recompute_layer

def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
@@ -606,12 +672,12 @@ def _core_forward(

output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
hidden_states = hidden_states + residual
hidden_states = self.attn_layer_scale(hidden_states) + residual

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
hidden_states = hidden_states + residual
hidden_states = self.mlp_layer_scale(hidden_states) + residual

return hidden_states, output["sequence_mask"]

@@ -719,13 +785,16 @@ def __init__(
]
)

self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
) # TODO
if config.use_final_norm:
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
) # TODO
else:
self.final_layer_norm = lambda input: {"hidden_states": input}

self.lm_head = PipelineBlock(
p2p=self.p2p,
@@ -770,12 +839,19 @@ def forward_with_hidden_states(
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)

hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"hidden_states": 50 * output["input_embeds"],
"sequence_mask": input_mask,
}
self.act_metrics = defaultdict(lambda: 0)
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)

with torch.no_grad():
act_rms = (hidden_encoder_states["hidden_states"]**2).mean().sqrt()
normed_acts = hidden_encoder_states["hidden_states"] / (act_rms + 1e-8)
self.act_metrics["avg_act_rms"] += act_rms / len(self.decoder)
self.act_metrics["avg_kurt"] += (normed_acts.view(-1, normed_acts.shape[-1])**2).mean(0).var() / len(self.decoder)

hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]

sharded_logits = self.lm_head(x=hidden_states)["logits"]
@@ -893,7 +969,7 @@ def forward(
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
return {"loss": loss}
return {"loss": loss, **self.model.act_metrics}

@torch.no_grad()
def init_model_randomly(self, config: Config):
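`forward_with_hidden_states` above also tracks two per-layer activation statistics that end up in the training logs: the RMS of the residual stream and a kurtosis proxy (the variance, across hidden channels, of each channel's mean squared normalised activation). A standalone sketch of the same computation, with an assumed shape and a toy outlier channel to show what the metric reacts to:

```python
import torch

def activation_metrics(hidden_states: torch.Tensor, eps: float = 1e-8):
    """hidden_states: [seq_length, batch_size, hidden_dim], as in the model forward."""
    act_rms = (hidden_states ** 2).mean().sqrt()           # global RMS of the residual stream
    normed = hidden_states / (act_rms + eps)               # make the statistic scale-free
    per_channel_power = (normed.view(-1, normed.shape[-1]) ** 2).mean(0)  # E[x_j^2] per channel j
    return act_rms, per_channel_power.var()                # outlier channels inflate this variance

x = torch.randn(128, 4, 64)
x_out = x.clone()
x_out[..., 0] *= 20                                        # one outlier feature
print(activation_metrics(x)[1], activation_metrics(x_out)[1])  # the second value is far larger
```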
14 changes: 14 additions & 0 deletions src/nanotron/nn/layer_norm.py
@@ -51,3 +51,17 @@ def forward(
is_rms_norm=True,
return_dropout_mask=return_dropout_mask,
)

class LayerScale(nn.Module):
def __init__(self, hidden_size, eps=1, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.eps = eps
self.reset_parameters()

def reset_parameters(self):
nn.init.constant_(self.weight, self.eps)

def forward(self, input):
return self.weight * input
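
A quick usage sketch for the `LayerScale` module added above; note that its `eps` argument is the initial per-channel gain (wired to `config.resid_gain` in the model), not a numerical-stability epsilon:

```python
import torch
from nanotron.nn.layer_norm import LayerScale  # as added in this commit

scale = LayerScale(hidden_size=16, eps=0.1)  # weight filled with 0.1
x = torch.randn(4, 16)
assert torch.allclose(scale(x), 0.1 * x)     # forward is just elementwise weight * input
```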
6 changes: 5 additions & 1 deletion src/nanotron/scaling/parametrization.py
@@ -4,7 +4,7 @@
from typing import Dict

from nanotron.config import ModelArgs
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.nn.layer_norm import TritonRMSNorm, LayerScale
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -38,6 +38,7 @@ def __init__(self, config: ModelArgs):
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
LayerScale: self._parametrize_layer_scale
}

self.std = config.init_method.std
@@ -69,6 +70,9 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
elif "bias" == param_name:
module.bias.zero_()

def _parametrize_layer_scale(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]

def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]

46 changes: 39 additions & 7 deletions src/nanotron/trainer.py
@@ -426,7 +426,7 @@ def train(
self._update_dataloader_based_on_training_stages(dataloader_or_dls)

# Training step
outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
outputs, loss_avg, kurt_avg, act_rms_avg = self.training_step(dataloader=self.current_dataloader)

# Training Logs
# TODO(xrsrke): refactor using callbacks would be better
@@ -437,7 +437,7 @@
].consumed_train_samples += self.global_batch_size

if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0:
self.train_step_logs(outputs=outputs, loss_avg=loss_avg)
self.train_step_logs(outputs=outputs, loss_avg=loss_avg, kurt_avg=kurt_avg, act_rms_avg=act_rms_avg)

# Checkpoint
if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
@@ -523,10 +523,34 @@ def training_step(
[output["loss"] for output in outputs]
).sum() # already divided by n_micro_batches_per_batch
# sync loss across DP
handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
loss_handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
else:
loss_avg = None
handle = None
loss_handle = None

# Compute DP average kurt and overlap with optimizer step
if isinstance(outputs[0]["avg_kurt"], torch.Tensor):
# This is an average on only one data rank.
kurt_avg = torch.stack(
[output["avg_kurt"] for output in outputs]
).sum() # already divided by n_micro_batches_per_batch
# sync kurt across DP
kurt_handle = dist.all_reduce(kurt_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
else:
kurt_avg = None
kurt_handle = None

# Compute DP average act_rms and overlap with optimizer step
if isinstance(outputs[0]["avg_act_rms"], torch.Tensor):
# This is an average on only one data rank.
act_rms_avg = torch.stack(
[output["avg_act_rms"] for output in outputs]
).sum() # already divided by n_micro_batches_per_batch
# sync act_rms across DP
act_rms_handle = dist.all_reduce(act_rms_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
else:
act_rms_avg = None
act_rms_handle = None

# Apply gradient
self.optimizer.step()
@@ -537,12 +561,16 @@

after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)

if handle is not None:
handle.wait()
if loss_handle is not None:
loss_handle.wait()
if kurt_handle is not None:
kurt_handle.wait()
if act_rms_handle is not None:
act_rms_handle.wait()

self.post_train_step()

return outputs, loss_avg
return outputs, loss_avg, kurt_avg, act_rms_avg

def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
outputs = self.pipeline_engine.validate_batch_iter(
@@ -556,6 +584,8 @@ def train_step_logs(
self,
outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
loss_avg: Optional[torch.Tensor],
kurt_avg: Optional[torch.Tensor],
act_rms_avg: Optional[torch.Tensor],
) -> None:
# TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
dist.barrier()
@@ -589,6 +619,8 @@ def train_step_logs(
), # , "1.6E"),
LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"),
LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"),
LogItem("kurt_avg", kurt_avg.item(), "human_format"), # , "1.6E"),
LogItem("act_rms_avg", act_rms_avg.item(), "human_format"), # , "1.6E"),
LogItem("lr", lr, "human_format"), # , ".3E"),
LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"),
LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
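The two new metric reductions in `training_step` follow the same pattern as the loss: launch an async all-reduce across data-parallel ranks, run the optimizer step while the communication is in flight, then wait on the handles. A minimal sketch of that overlap pattern (assumes `torch.distributed` is already initialised and `dp_pg` is the data-parallel process group; not the trainer's actual code):

```python
import torch
import torch.distributed as dist

def reduce_metrics_overlapped(metrics, dp_pg, optimizer):
    """metrics: dict of scalar tensors (e.g. loss, avg_kurt, avg_act_rms), averaged in place."""
    handles = [
        dist.all_reduce(t, group=dp_pg, async_op=True, op=dist.ReduceOp.AVG)
        for t in metrics.values() if t is not None
    ]
    optimizer.step()          # overlap the DP communication with the optimizer step
    for handle in handles:
        handle.wait()         # after this, each tensor holds the DP-averaged value
    return metrics
```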
