diff --git a/examples/config_llama3.2-3B.yaml b/examples/config_llama3.2-3B.yaml new file mode 100644 index 00000000..3128671b --- /dev/null +++ b/examples/config_llama3.2-3B.yaml @@ -0,0 +1,100 @@ +checkpoints: + checkpoint_interval: 20 + checkpoints_path: checkpoints/nanotron_training/Nanotron-Llama-3.2-3B-FT + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B + save_initial_state: false + load_lr_scheduler: false + load_optimizer: false +data_stages: +- data: + dataset: + dataset_folder: /store/swissai/a06/datasets_tokenized/nanotron/Meta-Llama-3-8B/fineweb-edu-sample-100BT + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: Llama3.2 + run: llama + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 128 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 3072 + initializer_range: 0.02 + intermediate_size: 8192 + is_llama_config: true + max_position_embeddings: 8192 + num_attention_heads: 24 + num_hidden_layers: 28 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: + factor: 32.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-3B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 4 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 8192 + train_steps: 200 + val_check_interval: -1 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c50334f6..92c91bdd 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -160,6 +160,8 @@ class CheckpointsArgs: save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[xPath] = None checkpoints_path_is_shared_file_system: Optional[bool] = False + load_lr_scheduler: Optional[bool] = True + load_optimizer: Optional[bool] = True def __post_init__(self): if isinstance(self.checkpoints_path, str): diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index d92de405..c6769004 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -50,7 +50,7 @@ class LlamaConfig: rope_theta: float = 10000.0 rope_interleaved: bool = ( False # The default value has been True, but for loading Llama3 checkpoints you have to set it to False - ) + ) # TODO(tj.solergibert) Depreciate this arg tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..bb811880 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,8 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, List, Optional, Union +import math +from typing import Dict, List, Optional, Union, Tuple import torch from torch import nn @@ -46,99 +47,83 @@ logger = logging.get_logger(__name__) +# NOTE(tj.solergibert) Copied from: https://github.com/huggingface/transformers/blob/2b053fdf1a638de17faa8791d96efac5e2507be7/src/transformers/modeling_rope_utils.py#L29 +def _compute_default_rope_parameters(config: LlamaConfig) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~nanotron.config.LlamaConfig`]): + The nanotron training configuration. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_theta + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) -class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, end: int, theta: float = 10000.0): - super().__init__() - assert dim % 2 == 0 - self.dim = dim - self.end = end - self.theta = theta - # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ... - # TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex - self.freqs_cis: torch.Tensor - self._initialized_buffer = False - - def init_rotary_embeddings(self): - if self._initialized_buffer is True: - # Buffer if already initialized - return - self.register_buffer( - "freqs_cis", - torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"), - persistent=False, - ) - assert self.freqs_cis.device.type == "cuda" - # TODO @nouamane: One we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert - if self.freqs_cis.dtype != torch.float: - self.freqs_cis = self.freqs_cis.to(torch.float) - assert self.freqs_cis.dtype == torch.float - freqs = 1.0 / ( - self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim) - ).to( - "cuda" - ) # should be computed on CPU, otherwise different results with Transformers. - t = torch.arange(self.end, device="cuda") - freqs = torch.outer(t, freqs).float() - complex_freqs = torch.polar(torch.ones_like(freqs), freqs) - freqs = torch.view_as_real(complex_freqs) - self.freqs_cis.copy_(freqs) - self._initialized_buffer = True + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim)) + return inv_freq - def forward( - self, - x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] - position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] - ): - batch_size, seq_length, num_heads, inner_dim = x.shape - while ( - position_ids is not None and position_ids[-1, -1] >= self.end - ) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync - self.end *= 2 - self._initialized_buffer = False - if self._initialized_buffer is False: - print(f"Initializing rotary embeddings with end={self.end}") - self.init_rotary_embeddings() - dtype = x.dtype - assert inner_dim % 2 == 0 - x = x.view( - batch_size, seq_length, num_heads, inner_dim // 2, 2 - ) # [batch_size, q_length, num_heads, inner_dim] - if x.dtype == torch.bfloat16: - x = x.float() - complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2] - if position_ids is None: - freqs_cis = self.freqs_cis[None, :seq_length, None, :] - else: - # TODO(kunhao): Should None follow the num_heads dimension? - if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully - raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}") - freqs_cis = self.freqs_cis[position_ids][:, :, None, :] - complex_freqs = torch.view_as_complex(freqs_cis) - x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim) - return x_out.type(dtype) +# NOTE(tj.solergibert) Copied from: https://github.com/huggingface/transformers/blob/2b053fdf1a638de17faa8791d96efac5e2507be7/src/transformers/modeling_rope_utils.py#L310 +def _compute_llama3_parameters(config: LlamaConfig) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + Args: + config ([`~nanotron.config.LlamaConfig`]): + The nanotron training configuration. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # Gets the default RoPE parameters + inv_freq = _compute_default_rope_parameters(config) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama + +# NOTE(tj.solergibert) Copied from: https://github.com/huggingface/transformers/blob/2b053fdf1a638de17faa8791d96efac5e2507be7/src/transformers/modeling_rope_utils.py#L353-L363 +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "llama3": _compute_llama3_parameters, +} -## Copy from transformers. Non interleaved version of RoPE. Will be refactored later class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim: int, end: int, theta: float = 500000.0): + def __init__( + self, + config: LlamaConfig + ): super().__init__() - self.dim = dim - self.end = end - self.theta = theta - self.init_rotary_embeddings() - - def init_rotary_embeddings(self): - inv_freq = 1.0 / ( - self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim) - ) # important to compute on CPU - self.register_buffer( - "inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False - ) - self.inv_freq = self.inv_freq.to( - torch.float - ) # make it float32 before copy to avoid precision loss during copy_ - self.inv_freq.copy_(inv_freq) + self.rope_type = config.rope_scaling.get("rope_type", "default") + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq = self.rope_init_fn(self.config) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.end = config.max_position_embeddings # NOTE(tj.solergibert) To support inference @torch.no_grad() def forward( @@ -149,47 +134,45 @@ def forward( # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) - def rotate_half(self, x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (self.rotate_half(q) * sin) - k_embed = (k * cos) + (self.rotate_half(k) * sin) - return q_embed, k_embed +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed class GLUActivation(nn.Module): def __init__(self, act_fn_name: str): @@ -262,39 +245,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg @checkpoint_method(attr_name="checkpoint_attention") def forward( self, - query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim] - key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim] - value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim] - q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) - kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] ): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. - causal = False if q_sequence_mask.shape[1] == 1 else True + from flash_attn.flash_attn_interface import flash_attn_func # NOTE: this scale is for µTransfer, # in SP, we use sqrt(1/d_h) softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None - attn_output = flash_attn_varlen_func( + attn_output = flash_attn_func( q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], - dropout_p=0.0, softmax_scale=softmax_scale, - causal=causal, - return_attn_probs=False, + causal=True, ) return attn_output @@ -391,25 +356,8 @@ def __init__( contiguous_chunks=qkv_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) - # TODO(kunhao): We want to have only one version per device and not one version per layer. - if config.rope_interleaved: - self.rotary_embedding = RotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - theta=config.rope_theta, - ) - else: - self.rotary_embedding = LlamaRotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - theta=config.rope_theta, - ) - self.rope_interleaved = config.rope_interleaved - - # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding( - dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved - ) + + self.rotary_emb = LlamaRotaryEmbedding(config=config) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -487,15 +435,16 @@ def forward( # Compute rotary embeddings # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache - old_rotary_embed_end = self.rotary_embedding.end + old_rotary_embed_end = self.rotary_emb.end # interleaved version. + # TODO(tj.solergibert) This might not work if self.rope_interleaved: - query_states = self.rotary_embedding(query_states, position_ids=position_ids) - key_states = self.rotary_embedding(key_states, position_ids=position_ids) + query_states = self.rotary_emb(query_states, position_ids=position_ids) + key_states = self.rotary_emb(key_states, position_ids=position_ids) # non interleaved version. else: - cos, sin = self.rotary_embedding(value_states, position_ids) - query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb( + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) @@ -567,14 +516,14 @@ def forward( # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values" # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache - if self.rotary_embedding.end > old_rotary_embed_end: + if self.rotary_emb.end > old_rotary_embed_end: k_cache = torch.cat( [ k_cache, torch.zeros( ( batch_size, - self.rotary_embedding.end - old_rotary_embed_end, + self.rotary_emb.end - old_rotary_embed_end, self.n_local_kv_heads, self.d_qk, ), @@ -591,7 +540,7 @@ def forward( torch.zeros( ( batch_size, - self.rotary_embedding.end - old_rotary_embed_end, + self.rotary_emb.end - old_rotary_embed_end, self.n_local_kv_heads, self.d_v, ), @@ -603,11 +552,11 @@ def forward( ) assert ( - k_cache.shape[1] == self.rotary_embedding.end - ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}" + k_cache.shape[1] == self.rotary_emb.end + ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_emb.end}" assert ( - v_cache.shape[1] == self.rotary_embedding.end - ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}" + v_cache.shape[1] == self.rotary_emb.end + ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_emb.end}" # [batch_size, seq_length, num_heads, d_qk] query_states = query_states.view( @@ -649,39 +598,19 @@ def forward( else: # Training case # Apply rotary embeddings to query/key states - # NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk] - # Here it is, [batch_size, seq_length, num_heads, d_qk] - # [2, batch_size, seq_length, num_heads, d_qk] - key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) - # [batch_size, seq_length, 2, num_heads, d_qk] - key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() - query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states) - # [batch_size, seq_length, num_heads, d_qk] - key_states, value_states = torch.split(key_value_states, 1, dim=2) - - q_sequence_mask = sequence_mask - kv_sequence_mask = sequence_mask - - kv_length = key_states.shape[1] - # [batch_size, seq_length, num_heads, d_qk] - # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func` - query_states = query_states.view( - batch_size * q_length, self.n_local_q_heads, self.d_qk - ) # [batch_size * q_length, self.n_heads, d_qk] - - key_states = key_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_qk - ) # [batch_size * kv_length, self.n_heads, d_qk] - value_states = value_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_v - ) # [batch_size * kv_length, self.n_heads, d_v] - + # TODO(tj.solergibert) Re-check to reduce number of transposes + position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1 + cos, sin = self.rotary_emb(value_states, position_ids) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + attention_output = self.attention( query_states=query_states, key_states=key_states, value_states=value_states, - q_sequence_mask=q_sequence_mask, - kv_sequence_mask=kv_sequence_mask, ) attention_output = ( diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index e9ad04d8..b1445b48 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -105,6 +105,7 @@ def save( save_lr_scheduler( lr_scheduler=lr_scheduler, + is_zero=config.optimizer.zero_stage, parallel_context=parallel_context, root_folder=root_folder, ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 4ed830de..7585d520 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -57,7 +57,7 @@ ) from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad -from nanotron.models.llama import LlamaForTraining, RotaryEmbedding +from nanotron.models.llama import LlamaForTraining from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext @@ -190,7 +190,7 @@ def __init__( optimizer_args=self.config.optimizer, parallel_context=self.parallel_context, ) - if self.init_checkpoint_path is not None: + if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer: load_optimizer( optimizer=self.optimizer, parallel_context=self.parallel_context, @@ -206,7 +206,7 @@ def __init__( lr_scheduler_args=self.config.optimizer.learning_rate_scheduler, total_training_steps=self.config.tokens.train_steps, ) - if self.init_checkpoint_path is not None: + if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler: load_lr_scheduler( lr_scheduler=self.lr_scheduler, is_zero=self.config.optimizer.zero_stage, @@ -215,7 +215,7 @@ def __init__( ) # Define iteration start state - if self.init_checkpoint_path is not None: + if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer: checkpoint_metadata = load_meta( parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path ) @@ -553,7 +553,7 @@ def training_step( handle = None # Move optimizer states back to GPU before optimizer step - if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step: + if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step and self.config.checkpoints.load_optimizer: state_dict_to_device(self.optimizer.state_dict(), "cuda") before_optim_step_sanity_checks( @@ -791,12 +791,6 @@ def _init_model( model_builder=model_builder, ) - # Initialize rotary embeddings - for module in model.modules(): - if not isinstance(module, RotaryEmbedding): - continue - module.init_rotary_embeddings() - # Mark some parameters as tied self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) diff --git a/tools/converters/convert_hf_to_nanotron.py b/tools/converters/convert_hf_to_nanotron.py new file mode 100644 index 00000000..c12388b3 --- /dev/null +++ b/tools/converters/convert_hf_to_nanotron.py @@ -0,0 +1,257 @@ +""" +torchrun --nproc-per-node 1 tools/converters/convert_hf_to_nanotron.py --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B --pretrained-model-name-or-path meta-llama/Llama-3.2-3B +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config = hf_model.config + + # Set Nanotron LlamaConfig + nanotron_llama_config = LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + ) + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_llama_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + with torch.no_grad(): + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + tmp_gate_up_proj = torch.cat( + [ + hf_model.model.layers[i].mlp.gate_proj.weight, + hf_model.model.layers[i].mlp.up_proj.weight, + ], + dim=0, + ) + + assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Llama3"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_llama_config, + ), + tokenizer=TokenizerArgs(nanotron_checkpoint_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_llama_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) \ No newline at end of file diff --git a/tools/converters/convert_nanotron_to_hf.py b/tools/converters/convert_nanotron_to_hf.py new file mode 100644 index 00000000..ca012c11 --- /dev/null +++ b/tools/converters/convert_nanotron_to_hf.py @@ -0,0 +1,221 @@ +""" +torchrun --nproc-per-node 1 tools/converters/convert_nanotron_to_hf.py --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B --hugging-face-checkpoint-path checkpoints/huggingface_converted/Converted-Nanotron-Llama-3.2-3B +""" +import argparse +import os +from dataclasses import asdict +from pathlib import Path + +import torch +from nanotron import logging +from nanotron.config import Config, LoggingArgs, ParallelismArgs, get_config_from_file +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.llama import LlamaConfig as LlamaConfigHF + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory with a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--hugging-face-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted checkpoint", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Nanotron checkpoint config + log_rank( + f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}", + logger=logger, + level=logging.INFO, + rank=0, + ) + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + nanotron_llama_config = nanotron_config.model.model_config + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Load Nanotron Checkpoint + log_rank("Loading Nanotron Llama3 Model...", logger=logger, level=logging.INFO, rank=0) + load_weights( + model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path) + ) + + # Build empty HF Model + log_rank("Init empty HF Llama3 Model", logger=logger, level=logging.INFO, rank=0) + hf_model = AutoModelForCausalLM.from_config( # WARN This takes a long time + config=LlamaConfigHF(**asdict(nanotron_llama_config)), + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + ).to(DEVICE) + + # Copy params from Nanotron to HF + log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) + with torch.no_grad(): + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + hf_model.model.embed_tokens.weight.copy_( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + hf_model.model.layers[i].input_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight + ) + + # Self attn + # Split Nanotrn qkv projection into q, k, v + q, k, v = torch.split( + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight, + [ + nanotron_llama_config.num_attention_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + ], + ) + assert q.shape == hf_model.model.layers[i].self_attn.q_proj.weight.shape + assert k.shape == hf_model.model.layers[i].self_attn.k_proj.weight.shape + assert v.shape == hf_model.model.layers[i].self_attn.v_proj.weight.shape + + hf_model.model.layers[i].self_attn.q_proj.weight.copy_(q) + hf_model.model.layers[i].self_attn.k_proj.weight.copy_(k) + hf_model.model.layers[i].self_attn.v_proj.weight.copy_(v) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + hf_model.model.layers[i].self_attn.o_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + gate_proj, up_proj = torch.split( + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight, + split_size_or_sections=[nanotron_llama_config.intermediate_size, nanotron_llama_config.intermediate_size], + ) + assert gate_proj.shape == hf_model.model.layers[i].mlp.gate_proj.weight.shape + assert up_proj.shape == hf_model.model.layers[i].mlp.up_proj.weight.shape + + hf_model.model.layers[i].mlp.gate_proj.weight.copy_(gate_proj) + hf_model.model.layers[i].mlp.up_proj.weight.copy_(up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + hf_model.model.layers[i].mlp.down_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + hf_model.model.layers[i].post_attention_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) + + log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) + hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) + # Store tokenizer + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokenizer.save_pretrained(args.hugging_face_checkpoint_path) + + log_rank( + f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) \ No newline at end of file diff --git a/tools/converters/delete/generate_hf_predictions.py b/tools/converters/delete/generate_hf_predictions.py new file mode 100644 index 00000000..065cdc0a --- /dev/null +++ b/tools/converters/delete/generate_hf_predictions.py @@ -0,0 +1,73 @@ +""" +torchrun --nproc-per-node 1 tools/converters/delete/generate_hf_predictions.py --pretrained-model-name-or-path meta-llama/Llama-3.2-3B +""" +import argparse +import os + +import numpy as np +import torch +from sklearn.metrics import accuracy_score +from transformers import AutoModelForCausalLM, AutoTokenizer + +TXT="Paris! Paris is the capital and most populous city of France, located in the north-central part of the country. It is a global center for art, fashion, cuisine, culture, and romance. Here's a brief overview: **History and Culture:**Paris has a rich history dating back to the 3rd century, with a blend of Roman, Gothic, Renaissance, and Art Nouveau influences. The city is famous for its iconic landmarks like the Eiffel Tower (built for the 1889 World's Fair), the Louvre Museum (home to the Mona Lisa), Notre-Dame Cathedral, and the Arc de Triomphe. **Art and Architecture:**Paris is renowned for its stunning architecture, with many beautiful bridges, gardens, and buildings. The city is also a hub for art, with numerous museums, galleries, and street performers. The Louvre, Musée d'Orsay, and Centre Pompidou are just a few of the many world-class museums. **Fashion and Cuisine:**Paris is considered the fashion capital of the world, with top designers like Chanel, Dior, and Louis Vuitton. The city is also famous for its exquisite cuisine, with popular dishes like escargots, croissants, baguettes, and cheese. Don't forget to try a classic French dessert like crème brûlée or macarons! **Romance and Entertainment:**Paris is often called the City of Light (La Ville Lumière) and the City of Love. It's a popular destination for couples and honeymooners, with its picturesque Seine River, charming streets, and cozy cafes. The city also hosts many festivals and events, including the French Open tennis tournament, the Tour de France, and the Rock en Seine music festival. **Economy and Education:** Paris is a global economic hub, with many multinational companies, startups, and universities. The city is home to some of the world's top universities, including the Sorbonne and École des Hautes Études en Sciences Sociales (EHESS). **Tourism:** Paris is one of the most visited cities in the world, attracting over 23 million tourists annually. Visitors come to experience the city's unique blend of history, culture, art, fashion, and romance. In summary, Paris is a vibrant, elegant, and enchanting city that offers something for everyone: history, art, fashion, cuisine, romance, and entertainment." +SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + + model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + ).to(DEVICE).eval() + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to(DEVICE) + inputs = tokens[:, :-1] + + with torch.no_grad(): + output = model(inputs) + + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.logits[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[HF Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + # Compute accuracy + predictions = np.argmax(output.logits.to(torch.float).cpu(), axis=2).flatten().tolist() + labels = tokens.cpu().flatten()[1:].tolist() + print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/converters/delete/generate_nanotron_predictions.py b/tools/converters/delete/generate_nanotron_predictions.py new file mode 100644 index 00000000..23c4803c --- /dev/null +++ b/tools/converters/delete/generate_nanotron_predictions.py @@ -0,0 +1,127 @@ +""" +torchrun --nproc-per-node 1 tools/converters/delete/generate_nanotron_predictions.py --tp 1 --nanotron-checkpoint-path /capstor/scratch/cscs/asolergi/nanotron/checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B +""" +import argparse +import os +from pathlib import Path + +import nanotron.distributed as dist +import numpy as np +import torch +from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from sklearn.metrics import accuracy_score +from transformers import AutoTokenizer + +TXT="Paris! Paris is the capital and most populous city of France, located in the north-central part of the country. It is a global center for art, fashion, cuisine, culture, and romance. Here's a brief overview: **History and Culture:**Paris has a rich history dating back to the 3rd century, with a blend of Roman, Gothic, Renaissance, and Art Nouveau influences. The city is famous for its iconic landmarks like the Eiffel Tower (built for the 1889 World's Fair), the Louvre Museum (home to the Mona Lisa), Notre-Dame Cathedral, and the Arc de Triomphe. **Art and Architecture:**Paris is renowned for its stunning architecture, with many beautiful bridges, gardens, and buildings. The city is also a hub for art, with numerous museums, galleries, and street performers. The Louvre, Musée d'Orsay, and Centre Pompidou are just a few of the many world-class museums. **Fashion and Cuisine:**Paris is considered the fashion capital of the world, with top designers like Chanel, Dior, and Louis Vuitton. The city is also famous for its exquisite cuisine, with popular dishes like escargots, croissants, baguettes, and cheese. Don't forget to try a classic French dessert like crème brûlée or macarons! **Romance and Entertainment:**Paris is often called the City of Light (La Ville Lumière) and the City of Love. It's a popular destination for couples and honeymooners, with its picturesque Seine River, charming streets, and cozy cafes. The city also hosts many festivals and events, including the French Open tennis tournament, the Tour de France, and the Rock en Seine music festival. **Economy and Education:** Paris is a global economic hub, with many multinational companies, startups, and universities. The city is home to some of the world's top universities, including the Sorbonne and École des Hautes Études en Sciences Sociales (EHESS). **Tourism:** Paris is one of the most visited cities in the world, attracting over 23 million tourists annually. Visitors come to experience the city's unique blend of history, culture, art, fashion, and romance. In summary, Paris is a vibrant, elegant, and enchanting city that offers something for everyone: history, art, fashion, cuisine, romance, and entertainment." +SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory containing a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + RANK = dist.get_rank(parallel_context.world_pg) + + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + + model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, # TODO Check with different parallelism if cpu is available + ) + + mark_tied_parameters(model=model, parallel_context=parallel_context) + sanity_check(root_module=model) + + # Load checkpoint directly in memory and then only keep the state dictionary + load_weights(model=model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)) + + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to(DEVICE) + inputs = {"input_ids": tokens[:, :-1], "input_mask": torch.ones((1, SEQ_LENGTH), device=DEVICE)} + + model.eval() + + with torch.no_grad(): + output = model.model(**inputs) + + if not RANK: + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.transpose(0, 1)[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[Nanotron Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + # Compute accuracy + predictions = np.argmax(output.transpose(0, 1).to(torch.float).cpu(), axis=2).flatten().tolist() + labels = tokens.cpu().flatten()[1:].tolist() + print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + +if __name__ == "__main__": + _args = get_args() + main(_args)