diff --git a/.gitignore b/.gitignore
index cbc04eaf..677e1298 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,4 @@
 cython_debug/
 checkpoints/
 wandb/
+slurm-*
\ No newline at end of file
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 5c92e1e7..58f8b32d 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -29,7 +29,7 @@
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
-from nanotron.nn.layer_norm import TritonRMSNorm
+from nanotron.nn.layer_norm import RMSNorm
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
@@ -130,173 +130,64 @@ def forward(
         return x_out.type(dtype)
 
 
-## Copy from transformers. Non interleaved version of RoPE. Will be refactored later
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-
-def _compute_default_rope_parameters(
-    config: Optional[LlamaConfig] = None,
-    device: Optional["torch.device"] = None,
-    seq_len: Optional[int] = None,
-    **rope_kwargs,
-) -> Tuple["torch.Tensor", float]:
-    """
-    Computes the inverse frequencies according to the original RoPE implementation
-    Args:
-        config ([`~transformers.PretrainedConfig`]):
-            The model configuration.
-        device (`torch.device`):
-            The device to use for initialization of the inverse frequencies.
-        seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
-        rope_kwargs (`Dict`, *optional*):
-            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
-    Returns:
-        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
-        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
-    """
-    if config is not None and len(rope_kwargs) > 0:
-        raise ValueError(
-            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
-            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
-        )
-    if len(rope_kwargs) > 0:
-        base = rope_kwargs["base"]
-        dim = rope_kwargs["dim"]
-    elif config is not None:
-        base = config.rope_theta
-        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
-
-    attention_factor = 1.0  # Unused in this type of RoPE
-
-    # Compute the inverse frequencies
-    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
-    return inv_freq, attention_factor
-
-def _compute_llama3_parameters(
-    config: LlamaConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
-) -> Tuple["torch.Tensor", float]:
-    """
-    Computes the inverse frequencies for llama 3.1.
-
-    Args:
-        config ([`~transformers.PretrainedConfig`]):
-            The model configuration.
-        device (`torch.device`):
-            The device to use for initialization of the inverse frequencies.
-        seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
-        rope_kwargs (`Dict`, *optional*):
-            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
-    Returns:
-        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
-        post-processing scaling factor applied to the computed cos/sin.
-    """
-    # Gets the default RoPE parameters
-    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
-
-    factor = config.rope_scaling["factor"]  # `8` in the original implementation
-    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
-    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
-    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
-
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
-
-    wavelen = 2 * math.pi / inv_freq
-    # wavelen < high_freq_wavelen: do nothing
-    # wavelen > low_freq_wavelen: divide by factor
-    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
-    # otherwise: interpolate between the two, using a smooth factor
-    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
-    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
-    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
-
-    return inv_freq_llama, attention_factor
-
-ROPE_INIT_FUNCTIONS = {
-    "default": _compute_default_rope_parameters,
-    "llama3": _compute_llama3_parameters,
-}
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
 
 class LlamaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[LlamaConfig] = None,
-    ):
+    def __init__(self, dim: int, end: int, theta: float = 500000.0):
         super().__init__()
-        # TODO (joao): remove the `if` below, only used for BC
-        self.rope_kwargs = {}
-        if config is None:
-            logger.warning_once(
-                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.45"
-            )
-            self.rope_kwargs = {
-                "rope_type": rope_type,
-                "factor": scaling_factor,
-                "dim": dim,
-                "base": base,
-                "max_position_embeddings": max_position_embeddings,
-            }
-            self.rope_type = rope_type
-            self.max_seq_len_cached = max_position_embeddings
-            self.original_max_seq_len = max_position_embeddings
-        else:
-            # BC: "rope_type" was originally "type"
-            if config.rope_scaling is not None:
-                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-            else:
-                self.rope_type = "default"
-            self.max_seq_len_cached = config.max_position_embeddings
-            self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
+        self.dim = dim
+        self.end = end
+        self.theta = theta
+        self.init_rotary_embeddings()
 
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
+    def init_rotary_embeddings(self):
+        inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
+        )  # important to compute on CPU
+        inv_freq = apply_scaling(inv_freq)
+        self.register_buffer(
+            "inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
+        )
+        self.inv_freq = self.inv_freq.to(
+            torch.float
+        )  # make it float32 before copy to avoid precision loss during copy_
+        self.inv_freq.copy_(inv_freq)
 
     @torch.no_grad()
-    def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
+    def forward(
+        self,
+        x: torch.Tensor,  # [batch_size, seq_length, num_heads, d_qk]
+        position_ids: Optional[torch.LongTensor],  # [batch_size, seq_length]
+    ):
+        # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
@@ -304,17 +195,11 @@ def forward(self, x, position_ids):
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos()
             sin = emb.sin()
-
-            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-            cos = cos * self.attention_scaling
-            sin = sin * self.attention_scaling
-
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
     """Applies Rotary Position Embedding to the query and key tensors.
-
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
@@ -336,71 +221,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-## Copy from transformers. Non interleaved version of RoPE. Will be refactored later
-# def rotate_half(x):
-#     """Rotates half the hidden dims of the input."""
-#     x1 = x[..., : x.shape[-1] // 2]
-#     x2 = x[..., x.shape[-1] // 2 :]
-#     return torch.cat((-x2, x1), dim=-1)
-
-
-# class LlamaRotaryEmbedding(nn.Module):
-#     def __init__(self, dim: int, end: int, theta: float = 500000.0):
-#         super().__init__()
-#         self.dim = dim
-#         self.end = end
-#         self.theta = theta
-#         self.init_rotary_embeddings()
-
-#     def init_rotary_embeddings(self):
-#         inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim))
-#         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-#     @torch.no_grad()
-#     def forward(
-#         self,
-#         x: torch.Tensor,  # [batch_size, seq_length, num_heads, d_qk]
-#         position_ids: Optional[torch.LongTensor],  # [batch_size, seq_length]
-#     ):
-#         # x: [bs, num_attention_heads, seq_len, head_size]
-#         # print("rotary")
-#         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-#         position_ids_expanded = position_ids[:, None, :].float()
-#         # Force float32 since bfloat16 loses precision on long contexts
-#         # See https://github.com/huggingface/transformers/pull/29285
-#         device_type = x.device.type
-#         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-#         with torch.autocast(device_type=device_type, enabled=False):
-#             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-#             emb = torch.cat((freqs, freqs), dim=-1)
-#             cos = emb.cos()
-#             sin = emb.sin()
-#         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-# def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
-#     """Applies Rotary Position Embedding to the query and key tensors.
-#     Args:
-#         q (`torch.Tensor`): The query tensor.
-#         k (`torch.Tensor`): The key tensor.
-#         cos (`torch.Tensor`): The cosine part of the rotary embedding.
-#         sin (`torch.Tensor`): The sine part of the rotary embedding.
-#         unsqueeze_dim (`int`, *optional*, defaults to 1):
-#             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-#             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-#             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-#             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-#             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-#             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-#     Returns:
-#         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-# """ -# cos = cos.unsqueeze(unsqueeze_dim) -# sin = sin.unsqueeze(unsqueeze_dim) -# q_embed = (q * cos) + (rotate_half(q) * sin) -# k_embed = (k * cos) + (rotate_half(k) * sin) -# return q_embed, k_embed - class GLUActivation(nn.Module): def __init__(self, act_fn_name: str): super().__init__() @@ -639,11 +459,11 @@ def __init__( else: self.rotary_embedding = LlamaRotaryEmbedding( dim=self.d_qk, - max_position_embeddings=config.max_position_embeddings, - #end=config.max_position_embeddings, - #theta=config.rope_theta, - base=config.rope_theta, - config=config, + end=config.max_position_embeddings, + theta=config.rope_theta, + #max_position_embeddings=config.max_position_embeddings, + #base=config.rope_theta, + #config=config, ) self.rope_interleaved = config.rope_interleaved @@ -955,7 +775,7 @@ def __init__( layer_idx: int, ): super().__init__() - self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, @@ -965,7 +785,7 @@ def __init__( ) self.layer_idx = layer_idx - self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) self.recompute_layer = parallel_config.recompute_layer @@ -1103,7 +923,7 @@ def __init__( self.final_layer_norm = PipelineBlock( p2p=self.p2p, - module_builder=TritonRMSNorm, + module_builder=RMSNorm, module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, module_input_keys={"input"}, module_output_keys={"hidden_states"}, diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index 688eaa78..ddf81337 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -51,3 +51,20 @@ def forward( is_rms_norm=True, return_dropout_mask=return_dropout_mask, ) + +# equivalent to TritonRMSNorm +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, input): + input_dtype = input.dtype + input = input.to(torch.float32) + variance = input.pow(2).mean(-1, keepdim=True) + input = input * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * input.to(input_dtype) \ No newline at end of file diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..5ad8f769 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -4,7 +4,7 @@ from typing import Dict from nanotron.config import ModelArgs -from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.layer_norm import RMSNorm, TritonRMSNorm from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs): TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, + RMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, } @@ -88,6 +89,7 @@ def __init__(self, config: ModelArgs): TensorParallelColumnLinear: self._parametrize_mup_weight, TensorParallelRowLinear: 
             TensorParallelRowLinear: self._parametrize_mup_weight,
             TritonRMSNorm: self._parametrize_layer_norm,
+            RMSNorm: self._parametrize_layer_norm,
             TensorParallelEmbedding: self._parametrize_embedding,
         }
         self.std = 1.0
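
A quick way to sanity-check the RoPE change in llama.py: the loop-based apply_scaling hard-codes the llama-3.1 factors (8 / 1 / 4 / 8192) that the removed _compute_llama3_parameters read from config.rope_scaling, so the two should yield identical inverse frequencies. The sketch below is not part of the patch; it assumes the patched nanotron (and its flash-attn dependency) is importable, and llama3_scaling_vectorized is just a local restatement of the deleted vectorized code.

import math

import torch

from nanotron.models.llama import apply_scaling


def llama3_scaling_vectorized(inv_freq: torch.Tensor) -> torch.Tensor:
    # Same constants that apply_scaling hard-codes.
    factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Low-frequency bands are divided by `factor`, high-frequency bands are kept as-is.
    inv_freq_scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # Medium bands interpolate smoothly between the two regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq_scaled / factor + smooth * inv_freq_scaled
    is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    return torch.where(is_medium, smoothed, inv_freq_scaled)


dim, theta = 128, 500000.0  # illustrative head dim and rope_theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
torch.testing.assert_close(apply_scaling(inv_freq), llama3_scaling_vectorized(inv_freq))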
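
The LlamaRotaryEmbedding constructor now takes dim/end/theta (wired from config.max_position_embeddings and config.rope_theta in CausalSelfAttention) instead of the HF-style max_position_embeddings/base/config kwargs. A minimal usage sketch follows, with made-up shapes and an assumed end value; it needs a CUDA device because init_rotary_embeddings allocates the inv_freq buffer on "cuda".

import torch

from nanotron.models.llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

batch, seq, heads, d_qk = 2, 16, 4, 128  # illustrative sizes only
rope = LlamaRotaryEmbedding(dim=d_qk, end=4096, theta=500000.0)

q = torch.randn(batch, seq, heads, d_qk, device="cuda")
k = torch.randn(batch, seq, heads, d_qk, device="cuda")
position_ids = torch.arange(seq, device="cuda").unsqueeze(0).expand(batch, -1)

cos, sin = rope(q, position_ids)  # each [batch, seq, d_qk], cast to q's dtype
# unsqueeze_dim=2 matches the [batch, seq, heads, d_qk] layout used here.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)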
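
layer_norm.py keeps TritonRMSNorm and adds the pure-PyTorch RMSNorm next to it, so the "equivalent to TritonRMSNorm" comment can be spot-checked directly. A rough parity check, not part of the patch: it assumes a CUDA machine with the flash-attn Triton kernels that TritonRMSNorm wraps, and the tolerance is loose because the inputs are bf16 while both norms accumulate in fp32.

import torch

from nanotron.nn.layer_norm import RMSNorm, TritonRMSNorm

hidden_size, eps = 4096, 1e-5
x = torch.randn(4, 128, hidden_size, dtype=torch.bfloat16, device="cuda")

triton_norm = TritonRMSNorm(hidden_size, eps=eps).to(device="cuda", dtype=torch.bfloat16)
torch_norm = RMSNorm(hidden_size, eps=eps).to(device="cuda", dtype=torch.bfloat16)
torch_norm.load_state_dict(triton_norm.state_dict(), strict=False)  # share the learned scale

with torch.no_grad():
    out_triton = triton_norm(x)
    out_torch = torch_norm(x)

torch.testing.assert_close(out_torch, out_triton, rtol=1e-2, atol=1e-2)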