From e34cabb1d1f7ee6ca0eae162c398e81583782757 Mon Sep 17 00:00:00 2001 From: emozilla Date: Sun, 11 Aug 2024 00:53:01 +0000 Subject: [PATCH] non-packed sft w/ ring attention --- run_train.py | 4 + src/nanotron/config/config.py | 2 +- src/nanotron/data/chat_dataset.py | 76 ++++-- src/nanotron/data/chat_tokenizer.py | 48 +++- src/nanotron/data/collator.py | 60 ++++- src/nanotron/data/dataloader_builder.py | 6 +- src/nanotron/models/llama.py | 254 ++++++++++++++++-- .../ring_flash_attn/zigzag_ring_flash_attn.py | 4 + src/nanotron/serialize/weights.py | 2 +- src/nanotron/trainer.py | 18 ++ tools/llama3/convert_nanotron_to_hf.py | 3 +- 11 files changed, 426 insertions(+), 51 deletions(-) diff --git a/run_train.py b/run_train.py index d5e3cfd7..2dcc6f0b 100644 --- a/run_train.py +++ b/run_train.py @@ -189,6 +189,8 @@ def get_dataloader_from_data_stage( conversation_column_name=data.dataset.conversation_column_name, dp_rank=trainer.parallel_context.dp_pg.rank(), dp_ranks_size=trainer.parallel_context.dp_pg.size(), + sp_ranks_size=trainer.parallel_context.sp_pg.size(), + seed=data.seed, ) # Prepare dataloader @@ -198,6 +200,8 @@ def get_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, ) + + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}") diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index cf1a6ae9..c73fd494 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -127,7 +127,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]] + dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/chat_dataset.py b/src/nanotron/data/chat_dataset.py index fba0001e..6b4aeec8 100644 --- a/src/nanotron/data/chat_dataset.py +++ b/src/nanotron/data/chat_dataset.py @@ -10,9 +10,13 @@ build_position_ids, build_position_ids_dummy, ) +from nanotron import logging +from nanotron.logging import log_rank from torch.utils.data import IterableDataset from transformers import AutoTokenizer +logger = logging.get_logger(__name__) + class ChatDataset(IterableDataset): """ @@ -44,6 +48,7 @@ def __init__( split: str = "train", dp_rank: int = 0, dp_ranks_size: int = 1, + sp_ranks_size: int = 1, skip_num_samples: int = None, # TODO(tj.solergibert) Delete, check later comment seed: int = 1234, ) -> None: @@ -60,6 +65,7 @@ def __init__( self.skip_num_samples = skip_num_samples self.seed = seed self.pack_samples = pack_samples + self.sp_chunks = sp_ranks_size * 2 if sp_ranks_size > 1 else 1 # Load, split and shuffle dataset self.dataset = load_dataset(dataset_path, split=split, streaming=True) @@ -87,6 +93,7 @@ def __iter__(self): buffer_tokens: List[int] = [] buffer_is_completition: List[int] = [] buffer_lengths: List[int] = [] + num_samples = 0 while True: for sample in iter(self.dataset): @@ -98,39 +105,68 @@ def __iter__(self): self.debug_tokenizer.apply_chat_template(sample["conversations"]) == tokens[:-1] ), f'{self.debug_tokenizer.apply_chat_template(sample["conversations"])}\n\n{tokens[:-1]}' - buffer_tokens.extend(tokens) - buffer_is_completition.extend(is_completition) - buffer_lengths.append(len(tokens)) + if self.pack_samples: + buffer_tokens.extend(tokens) + buffer_is_completition.extend(is_completition) + 
buffer_lengths.append(len(tokens)) - if len(buffer_tokens) > max_buffer_token_len or not self.pack_samples: # Can't pack more samples, yield - if self.pack_samples: + if len(buffer_tokens) > max_buffer_token_len: # Can't pack more samples, yield # Pop last sample from buffers sample_tokens = buffer_tokens[: -len(tokens)] sample_completitions = buffer_is_completition[: -len(tokens)] sample_lengths = buffer_lengths[:-1] - # TODO(tj.solergibert) Delete (debug) - assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths) + # TODO(tj.solergibert) Delete (debug) + assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths) - # Reset tokens buffers - buffer_tokens = tokens.copy() - buffer_is_completition = is_completition.copy() - buffer_lengths = [len(tokens)] + # Reset tokens buffers + buffer_tokens = tokens.copy() + buffer_is_completition = is_completition.copy() + buffer_lengths = [len(tokens)] - # TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting - sample_completitions = self.create_labels(sample_tokens, sample_completitions) + # TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting + sample_completitions = self.create_labels(sample_tokens, sample_completitions) - # TODO(tj.solergibert) Delete (debug), just 4 switching the remove cross-attention setting - position_ids = self.create_position_ids(sample_lengths, self.sequence_length) + # TODO(tj.solergibert) Delete (debug), just 4 switching the remove cross-attention setting + position_ids = self.create_position_ids(sample_lengths, self.sequence_length) - # TODO(tj.solergibert) Delete (debug) - # assert len(sample_tokens) <= max_buffer_token_len + # TODO(tj.solergibert) Delete (debug) + # assert len(sample_tokens) <= max_buffer_token_len + + yield { + "input_ids": np.array(sample_tokens, dtype=np.int32), + "label_mask": np.array([1 if x else 0 for x in sample_completitions], dtype=np.int32), + "position_ids": position_ids, + } + else: + # TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting + is_completition = self.create_labels(tokens, is_completition) + input_mask = ([1] * len(tokens)) + + rem = len(tokens) % self.sp_chunks + if rem != 0: + pad_amount = self.sp_chunks - rem + tokens.extend([self.chat_tokenizer.tokenizer.pad_token_id] * pad_amount) + is_completition.extend([False] * pad_amount) + input_mask.extend([0] * pad_amount) + + if self.sp_chunks > 1: + # sequence needs to be of length (closest multiple of 2 * sp_pg.size()) + 1 + # + 1 is so we have (closest multiple of 2 * sp_pg.size()) after shifting by one to get causal prediction + tokens.append(self.chat_tokenizer.tokenizer.pad_token_id) + is_completition.append(False) + input_mask.append(0) + + assert len(tokens) == len(input_mask) + assert len(input_mask) == len(is_completition) + label_mask = [1 if x else 0 for x in is_completition] yield { - "input_ids": np.array(sample_tokens, dtype=np.int32), - "is_completitions": np.array(sample_completitions, dtype=np.bool_), - "position_ids": position_ids, + "input_ids": np.array(tokens, dtype=np.int32), + "input_mask": np.array(input_mask, dtype=np.int32), + "label_mask": np.array(label_mask, dtype=np.int32), } + num_samples += 1 # TODO(tj.solergibert) Change for log_rank (log_rank is problematic with JupyterNB) print("Consumed all samples, dataset is being re-looped.") \ No newline at end of file diff --git a/src/nanotron/data/chat_tokenizer.py 
b/src/nanotron/data/chat_tokenizer.py index 47762e3b..77430c68 100644 --- a/src/nanotron/data/chat_tokenizer.py +++ b/src/nanotron/data/chat_tokenizer.py @@ -1,21 +1,38 @@ from typing import List, Tuple +from enum import Enum, auto from transformers import AutoTokenizer +class ChatFormat(Enum): + LLAMA3 = auto() + CHATML = auto() + + class ChatTokenizer: """ The ChatTokenizer encodes a conversation applying the Llama3 Chat Template and returns the role (Either User or Assistant) of each token + Args: tokenizer_name_or_path (str): A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub. """ - def __init__(self, tokenizer_name_or_path: str): + def __init__(self, tokenizer_name_or_path: str, chat_format: ChatFormat = ChatFormat.LLAMA3): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + self._chat_format = chat_format + if chat_format == ChatFormat.LLAMA3: + self._header_start = "<|start_header_id|>" + self._header_end = "<|end_header_id|>\n\n" + self._turn_end = "<|eot_id|>" + elif chat_format == ChatFormat.CHATML: + self._header_start = "<|im_start|>" + self._header_end = "\n" + self._turn_end = "<|im_end|>\n" + # Add pad token if necessary if self.tokenizer.pad_token is None: - self.tokenizer.add_special_tokens({"pad_token": "<|eot_id|>"}) + self.tokenizer.add_special_tokens({"pad_token": self._turn_end}) def __call__(self, conversation: List[dict]) -> Tuple[List[int], List[bool]]: """ @@ -27,11 +44,15 @@ def __call__(self, conversation: List[dict]) -> Tuple[List[int], List[bool]]: conversation: [ { "from": "system", "value": "You are an AI assistant that follows instruction extremely well. Help as much as you can."}, { "from": "human", "value": "Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about.\nAnswer:"}, { "from": "gpt", "value": "The information provided seems to refer to Rian Wallace, a former NFL player."} ] + After applying chat template: <|begin_of_text|><|start_header_id|>system<|end_header_id|> + You are an AI assistant that follows instruction extremely well. Help as much as you can.<|eot_id|><|start_header_id|>human<|end_header_id|> + Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about. Answer:<|eot_id|><|start_header_id|>gpt<|end_header_id|> + The information provided seems to refer to Rian Wallace, a former NFL player.<|eot_id|> returns: tokens (List[int]): A list of tokens e.g. 
[128000, 128006, 9125, 128007, 271, 2675, 527, ..., 12873, 2851, 13, 128009, 128001] @@ -60,9 +81,11 @@ def encode_message(self, message: dict) -> Tuple[List[int], List[int]]: # single format and document it properly rather than supporting multiple formats, as each DATASET will need a different # ChatTokenizer and the idea is that all Datasets share the same ChatTokenizer + role, is_input = self._get_role(message) + # Encode header tokens = self.tokenizer.encode( - f"<|start_header_id|>{message['from']}<|end_header_id|>\n\n", add_special_tokens=False + f"{self._header_start}{role}{self._header_end}", add_special_tokens=False ) is_completitions = [False] * len(tokens) @@ -70,9 +93,22 @@ def encode_message(self, message: dict) -> Tuple[List[int], List[int]]: tokens.extend(self.tokenizer.encode(message["value"].strip(), add_special_tokens=False)) # Append <|eot_id|> token - tokens.extend(self.tokenizer.encode("<|eot_id|>", add_special_tokens=False)) + tokens.extend(self.tokenizer.encode(self._turn_end, add_special_tokens=False)) # True if token belongs to assistant answer, False otherwise - is_completitions.extend([True if message["from"] == "gpt" else False] * (len(tokens) - len(is_completitions))) + is_completitions.extend([not is_input] * (len(tokens) - len(is_completitions))) + + return tokens, is_completitions - return tokens, is_completitions \ No newline at end of file + def _get_role(self, message: dict) -> Tuple[str, bool]: + """ + Return the canonical role for a given message, as well as if its value + should be considered input (and therefore not trained on) + """ + role = message["from"] + if role == "gpt" or role == "assistant": + return "assistant", False + elif role == "human": + return "user", True + else: + return role, True \ No newline at end of file diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 5fdf840c..33f6b4ed 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -130,7 +130,7 @@ def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, Union[torc } input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) - is_completitions = np.vstack([examples[i]["is_completitions"] for i in range(len(examples))]) # (b, s) + label_mask = np.vstack([examples[i]["label_mask"] for i in range(len(examples))]) # (b, s) position_ids = np.vstack([examples[i]["position_ids"] for i in range(len(examples))]) # (b, s) result: Dict[str, Union[np.ndarray, TensorPointer]] = {} @@ -148,7 +148,63 @@ def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, Union[torc # Process labels: shift them to the left. if current_pp_rank == self.output_pp_rank: result["label_ids"] = input_ids[:, 1:] - result["label_mask"] = is_completitions[:, 1:] + result["label_mask"] = label_mask[:, 1:] + + # Cast np.array to torch.Tensor + result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()} + return result + + +@dataclasses.dataclass +class DataCollatorForUnpackedSFT: + """ + Data collator used with Chat Dataset. + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. 
+ """ + + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + assert len(examples) == 1 + + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + input_mask = np.vstack([examples[i]["input_mask"] for i in range(len(examples))]) # (b, s) + label_mask = np.vstack([examples[i]["label_mask"] for i in range(len(examples))]) # (b, s) + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + # Process inputs + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = input_mask[:, :-1] + + # Process labels: shift them to the left. + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = label_mask[:, 1:] # Cast np.array to torch.Tensor result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()} diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 8303882d..c9ca7f88 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import DataCollatorForSFT, NanosetDataCollatorForCLM +from nanotron.data.collator import DataCollatorForSFT, DataCollatorForUnpackedSFT, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -80,6 +80,10 @@ def build_chat_dataloader( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, parallel_context=parallel_context, + ) if dataset.pack_samples else DataCollatorForUnpackedSFT( + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, ) dp_rank = parallel_context.dp_pg.rank() diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 5bdaaf4b..5c92e1e7 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,8 +14,9 @@ # limitations under the License. 
"""PyTorch LLaMa model.""" -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Tuple +import math import torch from torch import nn from torch.utils.checkpoint import CheckpointFunction @@ -137,30 +138,165 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +def _compute_default_rope_parameters( + config: Optional[LlamaConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + +def _compute_llama3_parameters( + config: LlamaConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. 
+ """ + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor + +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "llama3": _compute_llama3_parameters, +} + class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim: int, end: int, theta: float = 500000.0): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): super().__init__() - self.dim = dim - self.end = end - self.theta = theta - self.init_rotary_embeddings() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - def init_rotary_embeddings(self): - inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim)) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward( - self, - x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] - position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] - ): - # x: [bs, num_attention_heads, seq_len, head_size] - # print("rotary") + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -168,6 +304,11 @@ def forward( emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -195,6 +336,70 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed +## Copy from transformers. Non interleaved version of RoPE. 
Will be refactored later +# def rotate_half(x): +# """Rotates half the hidden dims of the input.""" +# x1 = x[..., : x.shape[-1] // 2] +# x2 = x[..., x.shape[-1] // 2 :] +# return torch.cat((-x2, x1), dim=-1) + + +# class LlamaRotaryEmbedding(nn.Module): +# def __init__(self, dim: int, end: int, theta: float = 500000.0): +# super().__init__() +# self.dim = dim +# self.end = end +# self.theta = theta +# self.init_rotary_embeddings() + +# def init_rotary_embeddings(self): +# inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim)) +# self.register_buffer("inv_freq", inv_freq, persistent=False) + +# @torch.no_grad() +# def forward( +# self, +# x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] +# position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] +# ): +# # x: [bs, num_attention_heads, seq_len, head_size] +# # print("rotary") +# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) +# position_ids_expanded = position_ids[:, None, :].float() +# # Force float32 since bfloat16 loses precision on long contexts +# # See https://github.com/huggingface/transformers/pull/29285 +# device_type = x.device.type +# device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" +# with torch.autocast(device_type=device_type, enabled=False): +# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) +# emb = torch.cat((freqs, freqs), dim=-1) +# cos = emb.cos() +# sin = emb.sin() +# return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2): +# """Applies Rotary Position Embedding to the query and key tensors. +# Args: +# q (`torch.Tensor`): The query tensor. +# k (`torch.Tensor`): The key tensor. +# cos (`torch.Tensor`): The cosine part of the rotary embedding. +# sin (`torch.Tensor`): The sine part of the rotary embedding. +# unsqueeze_dim (`int`, *optional*, defaults to 1): +# The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and +# sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note +# that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and +# k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes +# cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have +# the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. +# Returns: +# `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+# """ +# cos = cos.unsqueeze(unsqueeze_dim) +# sin = sin.unsqueeze(unsqueeze_dim) +# q_embed = (q * cos) + (rotate_half(q) * sin) +# k_embed = (k * cos) + (rotate_half(k) * sin) +# return q_embed, k_embed class GLUActivation(nn.Module): def __init__(self, act_fn_name: str): @@ -434,8 +639,11 @@ def __init__( else: self.rotary_embedding = LlamaRotaryEmbedding( dim=self.d_qk, - end=config.max_position_embeddings, - theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + #end=config.max_position_embeddings, + #theta=config.rope_theta, + base=config.rope_theta, + config=config, ) self.rope_interleaved = config.rope_interleaved @@ -948,6 +1156,9 @@ def forward_with_hidden_states( input_ids, input_mask, position_ids = zigzag_split( rank, world_size, input_ids, input_mask, position_ids ) + # logger.log(logging.INFO, f"input_ids: {input_ids}") + # logger.log(logging.INFO, f"input_mask: {input_mask}") + # logger.log(logging.INFO, f"position_ids: {position_ids}") else: position_ids = TensorPointer(input_ids.group_rank) # all tensors are optional as most ranks don't need anything from the dataloader. @@ -1033,9 +1244,14 @@ def forward( world_size = self.sp_pg.size() rank = dist.get_rank(self.sp_pg) label_ids, label_mask = zigzag_split(rank, world_size, label_ids, label_mask) + # logger.log(logging.INFO, f"label_ids: {label_ids}") + # logger.log(logging.INFO, f"label_mask: {label_mask}") loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) + #if torch.isnan(loss.detach()).any().item(): + # logger.log(logging.INFO, f"label_ids: {label_ids}") + # logger.log(logging.INFO, f"loss: {loss}") # TODO @thomasw21: It's unclear what kind of normalization we want to do. loss = masked_mean(loss, label_mask, dtype=torch.float) # I think indexing causes a sync we don't actually want diff --git a/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py b/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py index 2b8ab712..40a22088 100644 --- a/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py +++ b/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py @@ -14,6 +14,7 @@ def zigzag_ring_flash_attn_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -36,6 +37,7 @@ def forward(q, k, v, causal): softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -87,6 +89,7 @@ def zigzag_ring_flash_attn_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -126,6 +129,7 @@ def backward(doubt, q, k, v, out, softmax_lse, causal): softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index 621c6f4c..41c2ec02 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -302,7 +302,7 @@ def load_weights( finally: assert ( current_checkpoint_version == checkpoint_version - ), f"Checkpoint version mismatch at {shards_path[0]}." + ), f"Checkpoint version mismatch at {shards_path[0]}. 
Got {current_checkpoint_version} but expected {checkpoint_version}" if checkpoint_version <= CHECKPOINT_VERSION: load_sharded_param_latest( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 76461481..aa134654 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -19,6 +19,7 @@ cast, ) +from nanotron.config.config import ChatDatasetsArgs import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader @@ -253,6 +254,23 @@ def __init__( # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: Optional[DataLoader] = None + # NOTE(tj.solergibert) Flatten batch size in SFT training + if isinstance(self.config.data_stages[0].data.dataset, ChatDatasetsArgs) and self.micro_batch_size != 1: + if self.config.data_stages[0].data.dataset.pack_samples: + self.sequence_length = self.micro_batch_size * self.config.tokens.sequence_length + self.micro_batch_size = 1 + self.global_batch_size = ( + self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size() + ) + log_rank( + f"Flattening Batch dimension for SFT training. global_batch_size: {self.global_batch_size}, micro_batch_size: {self.micro_batch_size}, sequence_length: {self.sequence_length}", + logger=logger, + level=logging.INFO, + rank=0, + ) + else: + raise ValueError("ChatDataset without sample packing requires micro_batch_size=1") + self.post_init() def pre_init(self): diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py index 4cfaa9fa..789b8b2c 100644 --- a/tools/llama3/convert_nanotron_to_hf.py +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -51,12 +51,13 @@ def get_args(): def main(args): # Init Nanotron Parallel Utilities - parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1, sp=1) parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, pipeline_parallel_size=parallel_config.pp, tensor_parallel_size=parallel_config.tp, + sequence_parallel_size=parallel_config.sp, ) set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs())
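The non-packed path added to ChatDataset pads each conversation so that, once the collator drops the last input token and the first label token, the remaining length is a multiple of 2 * sp_pg.size(), which is the shape zigzag_split needs to hand every sequence-parallel rank one chunk from the front of the sequence and its mirror from the back. A minimal sketch of that invariant, assuming sp_size > 1 (pad_for_zigzag and zigzag_chunks below are illustrative helpers, not functions from this patch):

    import numpy as np

    def pad_for_zigzag(tokens, pad_id, sp_size):
        # Pad so that len(tokens) - 1 is a multiple of 2 * sp_size.
        chunks = 2 * sp_size
        out = list(tokens)
        rem = len(out) % chunks
        if rem:
            out.extend([pad_id] * (chunks - rem))
        # +1 so a multiple of 2 * sp_size remains after the shift-by-one for causal prediction.
        out.append(pad_id)
        return out

    def zigzag_chunks(seq, sp_size, rank):
        # Rank r takes chunk r from the front and chunk (2 * sp_size - 1 - r) from the back.
        chunks = np.array_split(np.asarray(seq), 2 * sp_size)
        return np.concatenate([chunks[rank], chunks[2 * sp_size - 1 - rank]])

    tokens = list(range(10))                    # toy "conversation"
    padded = pad_for_zigzag(tokens, pad_id=0, sp_size=2)
    shifted = padded[:-1]                       # what the model actually sees after the shift
    assert len(shifted) % (2 * 2) == 0
    print([zigzag_chunks(shifted, 2, r).tolist() for r in range(2)])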
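ChatTokenizer now carries template strings for two chat formats. For reference, a single turn rendered with the _header_start / _header_end / _turn_end values set in __init__ looks like this (a standalone sketch of the templates and of the role mapping, not the class itself):

    def render_turn(role, text, chat_format="llama3"):
        # Mirrors the header/turn markers configured in ChatTokenizer.__init__.
        if chat_format == "llama3":
            return f"<|start_header_id|>{role}<|end_header_id|>\n\n{text}<|eot_id|>"
        if chat_format == "chatml":
            return f"<|im_start|>{role}\n{text}<|im_end|>\n"
        raise ValueError(chat_format)

    def canonical_role(frm):
        # Mirrors ChatTokenizer._get_role: only assistant turns count as completions.
        if frm in ("gpt", "assistant"):
            return "assistant", False   # not input -> trained on
        if frm == "human":
            return "user", True
        return frm, True

    print(render_turn(canonical_role("gpt")[0], "Hello!", "llama3"))
    print(render_turn(canonical_role("human")[0], "Hi", "chatml"))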
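With pack_samples disabled, the new DataCollatorForUnpackedSFT receives exactly one conversation per micro-batch (hence the assert len(examples) == 1) and only shifts by one: input_ids and input_mask drop the last position, label_ids and label_mask drop the first, so position t predicts token t+1 and the loss is taken only where label_mask is 1. A toy illustration of that contract, with made-up token values:

    import numpy as np

    tokens     = np.array([[10, 11, 12, 13, 14, 0]])   # last position is padding
    input_mask = np.array([[ 1,  1,  1,  1,  1, 0]])
    label_mask = np.array([[ 0,  0,  1,  1,  1, 0]])   # train only on assistant tokens

    batch = {
        "input_ids":  tokens[:, :-1],      # what the model sees
        "input_mask": input_mask[:, :-1],
        "label_ids":  tokens[:, 1:],       # next-token targets
        "label_mask": label_mask[:, 1:],   # positions that contribute to the loss
    }
    for k, v in batch.items():
        print(k, v.tolist())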
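The RoPE rework ports the Llama 3.1 frequency scaling from transformers: high-frequency components are left untouched, low-frequency components are divided by factor, and the band in between is blended smoothly. The same rule restated as a standalone function (default values match the comments in _compute_llama3_parameters; a sketch for reading, not the code path the model calls):

    import math
    import torch

    def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                              high_freq_factor=4.0, old_context_len=8192):
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        wavelen = 2 * math.pi / inv_freq

        # Long wavelengths (low frequencies) are compressed by `factor`; short ones are untouched.
        scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # The medium band is a linear blend between the scaled and unscaled frequencies.
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
        is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
        return torch.where(is_medium, smoothed, scaled)

    dim, base = 128, 500000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    print(llama3_scale_inv_freq(inv_freq)[:4])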
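Finally, the trainer change only applies to packed SFT: it folds the batch dimension into the sequence dimension so the per-step token budget is unchanged, while the non-packed path requires micro_batch_size=1 up front. With illustrative numbers:

    # Illustrative numbers only: folding batch into sequence keeps tokens per step constant.
    micro_batch_size, sequence_length = 4, 2048
    n_micro_batches_per_batch, dp_size = 8, 2

    sequence_length *= micro_batch_size                                          # 8192
    micro_batch_size = 1
    global_batch_size = micro_batch_size * n_micro_batches_per_batch * dp_size   # 16
    print(sequence_length, global_batch_size)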