From 1edaa957c387b029fead54d2623bd41e17afbd67 Mon Sep 17 00:00:00 2001 From: bobby-he Date: Tue, 1 Oct 2024 18:28:41 +0200 Subject: [PATCH] add outlier protected block (qk-norm, LayerScale, removed pre-norms, non-gated MLPs) --- src/nanotron/config/models_config.py | 5 ++ src/nanotron/models/llama.py | 112 ++++++++++++++++++++---- src/nanotron/nn/layer_norm.py | 14 +++ src/nanotron/scaling/parametrization.py | 6 +- src/nanotron/trainer.py | 46 ++++++++-- 5 files changed, 157 insertions(+), 26 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 2630e1d6..e831cbee 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -54,6 +54,11 @@ class LlamaConfig: tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 + norm_type: str = "rmsnorm" + qknorm_type: str = "none" + resid_gain: float = 1. + use_final_norm: bool = True + use_gated_mlp: bool = False def __post_init__(self): # NOTE: user don't set self._init_method, ModelArgs will set it diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 28a2e30f..17e0da49 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -15,7 +15,7 @@ """PyTorch LLaMa model.""" from typing import Dict, Optional, Union, List - +from collections import defaultdict import torch from torch import nn from torch.utils.checkpoint import CheckpointFunction @@ -28,7 +28,7 @@ from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.nn.activations import ACT2FN -from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.layer_norm import TritonRMSNorm, LayerScale from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer @@ -129,6 +129,48 @@ def forward(self, merged_states: torch.Tensor): class MLP(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + gate_up_contiguous_chunks = ( + config.intermediate_size, # shape of gate_linear + ) + self.up_proj = TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=gate_up_contiguous_chunks, + ) + self.down_proj = TensorParallelRowLinear( + config.intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + merged_states = self.up_proj(hidden_states) + hidden_states = self.down_proj(self.act(merged_states)) + return {"hidden_states": hidden_states} + +class GatedMLP(nn.Module): def __init__( self, config: LlamaConfig, @@ -172,7 +214,6 @@ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] hidden_states = self.down_proj(self.split_silu_mul(merged_states)) return {"hidden_states": hidden_states} - class CoreAttention(nn.Module): def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int): super().__init__() @@ -334,6 +375,12 @@ def __init__( self.prefill_kv_len = ( config.max_position_embeddings ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings + if config.qknorm_type == "rmsnorm": + self.q_norm = TritonRMSNorm(self.d_qk, eps=config.rms_norm_eps) + self.k_norm = TritonRMSNorm(self.d_qk, eps=config.rms_norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() def forward( self, @@ -378,6 +425,10 @@ def forward( .contiguous() ) # [3, batch_size, seq_length, n_local_q_heads, d_qk] + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + store = self.get_local_store() if store is not None: # Inference case # Double check that we use store only at inference time @@ -583,7 +634,13 @@ def __init__( layer_idx: int, ): super().__init__() - self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.norm_type == "rmsnorm": + self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.input_layernorm = nn.Identity() + self.post_attention_layernorm = nn.Identity() + self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, @@ -591,11 +648,20 @@ def __init__( layer_idx=layer_idx, ) - self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + if config.resid_gain != 1: + self.attn_layer_scale = LayerScale(config.hidden_size, eps=config.resid_gain) + self.mlp_layer_scale = LayerScale(config.hidden_size, eps=config.resid_gain) + else: + self.attn_layer_scale = nn.Identity() + self.mlp_layer_scale = nn.Identity() - self.recompute_layer = parallel_config.recompute_layer + if config.use_gated_mlp: + self.mlp = GatedMLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + else: + self.mlp = MLP(config=config, 
parallel_config=parallel_config, tp_pg=tp_pg) + self.recompute_layer = parallel_config.recompute_layer + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -606,12 +672,12 @@ def _core_forward( output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] - hidden_states = hidden_states + residual + hidden_states = self.attn_layer_scale(hidden_states) + residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual + hidden_states = self.mlp_layer_scale(hidden_states) + residual return hidden_states, output["sequence_mask"] @@ -719,13 +785,16 @@ def __init__( ] ) - self.final_layer_norm = PipelineBlock( - p2p=self.p2p, - module_builder=TritonRMSNorm, - module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) # TODO + if config.use_final_norm: + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonRMSNorm, + module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) # TODO + else: + self.final_layer_norm = lambda input: {"hidden_states": input} self.lm_head = PipelineBlock( p2p=self.p2p, @@ -770,12 +839,19 @@ def forward_with_hidden_states( output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) hidden_encoder_states = { - "hidden_states": output["input_embeds"], + "hidden_states": 50 * output["input_embeds"], "sequence_mask": input_mask, } + self.act_metrics = defaultdict(lambda: 0) for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + with torch.no_grad(): + act_rms = (hidden_encoder_states["hidden_states"]**2).mean().sqrt() + normed_acts = hidden_encoder_states["hidden_states"] / (act_rms + 1e-8) + self.act_metrics["avg_act_rms"] += act_rms / len(self.decoder) + self.act_metrics["avg_kurt"] += (normed_acts.view(-1, normed_acts.shape[-1])**2).mean(0).var() / len(self.decoder) + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] @@ -893,7 +969,7 @@ def forward( label_ids=label_ids, label_mask=label_mask, )["loss"] - return {"loss": loss} + return {"loss": loss, **self.model.act_metrics} @torch.no_grad() def init_model_randomly(self, config: Config): diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index 688eaa78..f5469fe5 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -51,3 +51,17 @@ def forward( is_rms_norm=True, return_dropout_mask=return_dropout_mask, ) + +class LayerScale(nn.Module): + def __init__(self, hidden_size, eps=1, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.eps = eps + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.weight, self.eps) + + def forward(self, input): + return self.weight * input diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..becbb9e8 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -4,7 +4,7 @@ from 
typing import Dict from nanotron.config import ModelArgs -from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.layer_norm import TritonRMSNorm, LayerScale from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -38,6 +38,7 @@ def __init__(self, config: ModelArgs): TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, + LayerScale: self._parametrize_layer_scale } self.std = config.init_method.std @@ -69,6 +70,9 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module): elif "bias" == param_name: module.bias.zero_() + def _parametrize_layer_scale(self, param_name: str, module: nn.Module): + assert param_name in ["weight"] + def _parametrize_embedding(self, param_name: str, module: nn.Module): assert param_name in ["weight"] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 21251a32..ebb674e7 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -426,7 +426,7 @@ def train( self._update_dataloader_based_on_training_stages(dataloader_or_dls) # Training step - outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + outputs, loss_avg, kurt_avg, act_rms_avg = self.training_step(dataloader=self.current_dataloader) # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -437,7 +437,7 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs(outputs=outputs, loss_avg=loss_avg, kurt_avg=kurt_avg, act_rms_avg=act_rms_avg) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -523,10 +523,34 @@ def training_step( [output["loss"] for output in outputs] ).sum() # already divided by n_micro_batches_per_batch # sync loss across DP - handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + loss_handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) else: loss_avg = None - handle = None + loss_handle = None + + # Compute DP average kurt and overlap with optimizer step + if isinstance(outputs[0]["avg_kurt"], torch.Tensor): + # This is an average on only one data rank. + kurt_avg = torch.stack( + [output["avg_kurt"] for output in outputs] + ).sum() # already divided by n_micro_batches_per_batch + # sync kurt across DP + kurt_handle = dist.all_reduce(kurt_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + else: + kurt_avg = None + kurt_handle = None + + # Compute DP average act_rms and overlap with optimizer step + if isinstance(outputs[0]["avg_act_rms"], torch.Tensor): + # This is an average on only one data rank. 
+ act_rms_avg = torch.stack( + [output["avg_act_rms"] for output in outputs] + ).sum() # already divided by n_micro_batches_per_batch + # sync act_rms across DP + act_rms_handle = dist.all_reduce(act_rms_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + else: + act_rms_avg = None + act_rms_handle = None # Apply gradient self.optimizer.step() @@ -537,12 +561,16 @@ def training_step( after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) - if handle is not None: - handle.wait() + if loss_handle is not None: + loss_handle.wait() + if kurt_handle is not None: + kurt_handle.wait() + if act_rms_handle is not None: + act_rms_handle.wait() self.post_train_step() - return outputs, loss_avg + return outputs, loss_avg, kurt_avg, act_rms_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: outputs = self.pipeline_engine.validate_batch_iter( @@ -556,6 +584,8 @@ def train_step_logs( self, outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + kurt_avg: Optional[torch.Tensor], + act_rms_avg: Optional[torch.Tensor], ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() @@ -589,6 +619,8 @@ def train_step_logs( ), # , "1.6E"), LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"), LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"), + LogItem("kurt_avg", kurt_avg.item(), "human_format"), # , "1.6E"), + LogItem("act_rms_avg", act_rms_avg.item(), "human_format"), # , "1.6E"), LogItem("lr", lr, "human_format"), # , ".3E"), LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
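
Taken together, the hunks above implement what the commit subject calls an outlier-protected block: RMSNorm on queries and keys inside attention, no pre-norms on the residual stream, LayerScale gains on both residual branches, and a non-gated MLP. Below is a minimal, self-contained sketch of the resulting decoder block, not part of the patch itself: plain nn.Linear and F.scaled_dot_product_attention stand in for nanotron's tensor-parallel linears and flash attention, rotary embeddings, the KV cache and the TP/PP plumbing are omitted, SiLU is assumed for hidden_act, and the LayerScale init argument is named init_value here where the patch reuses the eps keyword.

import torch
from torch import nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Plain-PyTorch stand-in for TritonRMSNorm."""
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()


class LayerScale(nn.Module):
    """Per-channel gain on a residual branch, initialised to a constant."""
    def __init__(self, hidden_size: int, init_value: float = 0.1):
        super().__init__()
        self.weight = nn.Parameter(torch.full((hidden_size,), init_value))

    def forward(self, x):
        return self.weight * x


class OutlierProtectedBlock(nn.Module):
    """QK-norm, no pre-norms, LayerScale residual branches, non-gated MLP."""
    def __init__(self, hidden_size: int, n_heads: int, intermediate_size: int, resid_gain: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = hidden_size // n_heads
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.q_norm = RMSNorm(self.d_head)  # qknorm_type == "rmsnorm"
        self.k_norm = RMSNorm(self.d_head)
        self.attn_scale = LayerScale(hidden_size, resid_gain)
        self.mlp_scale = LayerScale(hidden_size, resid_gain)
        # use_gated_mlp == False: up projection, activation, down projection.
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):  # x: [batch, seq, hidden]
        b, s, _ = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)          # QK-RMSNorm per head
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = self.o_proj(attn.transpose(1, 2).reshape(b, s, -1))
        x = x + self.attn_scale(attn)                  # no input_layernorm on the branch
        mlp = self.down_proj(F.silu(self.up_proj(x)))  # non-gated MLP
        return x + self.mlp_scale(mlp)                 # no post_attention_layernorm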
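
The new LlamaConfig fields select between the stock pre-norm block and the variant above. Note that use_gated_mlp now defaults to False, so the original gated MLP has to be requested explicitly; the other defaults preserve the previous behaviour. The overrides below are illustrative values, not a config shipped with the patch.

# Reproduce the previous Llama block.
baseline_overrides = dict(
    norm_type="rmsnorm",    # keep input / post-attention RMSNorm
    qknorm_type="none",     # no QK-norm
    resid_gain=1.0,         # LayerScale branches replaced by nn.Identity
    use_final_norm=True,    # final RMSNorm before the LM head
    use_gated_mlp=True,     # restore the gated (GatedMLP) feed-forward
)

# Outlier-protected block. resid_gain=0.1 is an illustrative value, not one
# prescribed by the patch; any norm_type other than "rmsnorm" drops the pre-norms.
op_block_overrides = dict(
    norm_type="none",
    qknorm_type="rmsnorm",  # RMSNorm on per-head queries and keys
    resid_gain=0.1,         # LayerScale weights initialised to 0.1
    use_final_norm=True,
    use_gated_mlp=False,    # plain up-proj / activation / down-proj MLP
)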
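
The patch also logs two diagnostics averaged over decoder layers: the RMS of the residual-stream activations and the quantity it names avg_kurt, i.e. the variance across hidden features of each feature's mean square after RMS-normalising the activations, which grows when a few features carry outsized values. A standalone equivalent of the per-layer computation in forward_with_hidden_states, assuming only that the hidden dimension is last:

import torch


@torch.no_grad()
def activation_outlier_metrics(hidden_states: torch.Tensor, eps: float = 1e-8):
    """Per-layer metrics mirroring the patch: overall activation RMS and a
    feature-outlier measure on the RMS-normalised activations."""
    act_rms = hidden_states.pow(2).mean().sqrt()
    normed = hidden_states / (act_rms + eps)
    per_feature_ms = normed.reshape(-1, normed.shape[-1]).pow(2).mean(dim=0)
    return act_rms, per_feature_ms.var()

In the patch these values are accumulated into act_metrics during the decoder loop, returned alongside the loss from LlamaForTraining.forward, all-reduced across data-parallel ranks in training_step, and reported as kurt_avg and act_rms_avg in train_step_logs.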