diff --git a/docs/nanoset.md b/docs/nanoset.md index 02649bd0..61393438 100644 --- a/docs/nanoset.md +++ b/docs/nanoset.md @@ -1,41 +1,42 @@ # Nanosets -Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a kind of datasets based on [numpy memory-mapped arrays](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html). `Nanosets` are capable of serving batches from files containing pre-tokenized datasets. They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches. +Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a dataset for processing tokenized documents with [`datatrove`](https://github.com/huggingface/datatrove). They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches. ## Install To use `Nanosets`, it's necessary to install Nanotron with the `nanosets` flavor. ``` -pip install -e '.[nanosets]' +pip install nanotron[nanosets] ``` This will install the following dependencies: -- `transformers`: To tokenize the datasets -- `datasets`: To preprocess the datasets +- `datatrove`: To preprocess the datasets - `numba`: To compile helper functions in order to speed up the creation of `Nanosets` +- `transformers`: For the tokenizers ## Data pre-processing -To use these datasets, first, we need to preprocess the data. The input format can either be a column of a Hugging Face Dataset or a .json file containing a text sample per line. For example: +To use this dataset, first, we need to preprocess the data using `datatrove`'s `DocumentTokenizer` pipeline. We invite you to take a look at `datatrove`, since it contains multiple features that allow, for example, filter out documents based on specific rules/criteria, extract text content from raw formats or scheduling the preprocessing in a Slurm cluster. We have also added a simple script capable of tokenizing datasets. -
-{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
-{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
-
- -The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. Below we show an example for processing a corpus with the Llama2 tokenizer. +The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. The input format can either be a Hugging Face Dataset, a path to a `.jsonl` or a path to a folder containing multiple `.jsonl` files. Below we show an example for processing a Hugging Face Dataset from the Hub with the Llama3 tokenizer.
-torchrun --nproc-per-node 16 tools/preprocess_data.py \
-       --input HuggingFaceH4/testing_alpaca_small \
-       --split train \
-       --column completion \
-       --output-prefix datasets/testing_alpaca_small \
-       --tokenizer-name-or-path openai-community/gpt2
+python3 tools/preprocess_data.py \
+       --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B \
+       --output-folder datasets/emotion \
+       --n-tasks 16 \
+       hf \
+       --dataset dair-ai/emotion \
 
-The preprocessing script has to be launched with `torchrun` in order to spawn `--nproc-per-node` workers that will preprocess the dataset concurrently. The `--input` dataset can be either a Hugging Face Dataset from the Hub or a `.json` file. The processed dataset will be stored in *`--output-prefix`_input_ids.npy*. In `--tokenizer-name-or-path`, we will have to specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`. +First with `--tokenizer-name-or-path` we will specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`. Then we specify the `--output-folder` where we will store the tokenized documents and the number of workers with `--n-tasks`. Finally we will indicate the type of dataset (whether if it's a Hugging Face Dataset ["**hf**"] or in jsonl ["**jsonl**"] format) and the dataset that we want to preprocess. Check the different settings with `python3 tools/preprocess_data.py --help`, `python3 tools/preprocess_data.py hf --help` & `python3 tools/preprocess_data.py jsonl --help`. -The output will be one file named, in this case, `datasets/testing_alpaca_small_input_ids.npy`. We will then have to specify this file in the `dataset_path` field in the config file. +Every worker will store in `--output-folder` 3 different kind of files: +- `*.ds` Containing the tokenized documents +- `*.ds.index` Containing the bounds of each tokenized document +- `*.ds.metadata` Containing the number of tokens and tokenizer used + +> [!IMPORTANT] +Remember to introduce the type of dataset to process. e.g. python3 tools/preprocess_data.py --tokenizer-name-or-path gpt2 --n-tasks 16 **jsonl** --dataset raw_datasets/c4-es-json-files ## Working with Nanosets To work with `Nanosets`, we just need to configure 1 argument: -1. `dataset_path`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it: +1. `dataset_folder`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it: 1. If we specify a single path, we will create a `Nanoset` from a single dataset file. ```yaml data_stages: @@ -43,7 +44,7 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 1 data: dataset: - dataset_path: datasets/SlimPajama-6B_input_ids.npy + dataset_folder: datasets/SlimPajama-6B num_loading_workers: 0 seed: 1234 ``` @@ -54,9 +55,9 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 15 data: dataset: - dataset_path: - - datasets/SlimPajama-6B_input_ids.npy - - datasets/testing_alpaca_small_input_ids.npy + dataset_folder: + - datasets/SlimPajama-6B + - datasets/testing_alpaca_small num_loading_workers: 0 seed: 1234 ``` @@ -67,9 +68,9 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 25 data: dataset: - dataset_path: - datasets/SlimPajama-6B_input_ids.npy: 0.8 - datasets/testing_alpaca_small_input_ids.npy: 0.2 + dataset_folder: + datasets/SlimPajama-6B: 0.8 + datasets/testing_alpaca_small: 0.2 num_loading_workers: 0 seed: 1234 ``` @@ -78,11 +79,14 @@ To work with `Nanosets`, we just need to configure 1 argument: Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py). ```shell -torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml +torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml ``` ## Under the hood -`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. The `dataset lengths` of each dataset will be determined by the `(dataset_number_of_tokens - 1) / sequence length`, discarding the last sample if its length < `sequence length`. +`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. Despite most of the extracting logic lies in `DatatroveFolderDataset`, `Nanosets` will take care of the following: +1. Creating dataset mixtures from different dataset folder paths +2. Ensure that in each epoch, we consume each sample only once +3. Ensure that we never exhaust the `DataLoader` Based on the `dataset lengths`, the `dataset weights` and the `number of samples per epoch` (defined as the `sum(dataset lengths)`), we build the two indexes we need in order to extract samples from the `Nanoset` ([build_nanoset_index_helper](../src/nanotron/data/nanoset.py)): - `dataset index`: Contains the index of the dataset from the list of `dataset paths` from which to extract the sample, respecting the established dataset weight. diff --git a/examples/config_nanoset.yaml b/examples/config_nanoset.yaml index 31f23bf0..127ddb5e 100644 --- a/examples/config_nanoset.yaml +++ b/examples/config_nanoset.yaml @@ -7,25 +7,25 @@ checkpoints: data_stages: - data: dataset: - dataset_path: datasets/testing_alpaca_small_input_ids.npy + dataset_folder: datasets/c4-es/tokenized num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) start_training_step: 1 - data: dataset: - dataset_path: - - datasets/yelp_review_full_input_ids.npy - - datasets/testing_alpaca_small_input_ids.npy + dataset_folder: + - datasets/SlimPajama-6B/tokenized + - datasets/c4-es/tokenized num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_path: - datasets/testing_alpaca_small_input_ids.npy: 0.8 - datasets/yelp_review_full_input_ids.npy: 0.2 + dataset_folder: + datasets/SlimPajama-6B/tokenized: 0.8 + datasets/c4-es/tokenized: 0.2 num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -57,7 +57,7 @@ model: initializer_range: 0.02 intermediate_size: 64 is_llama_config: true - max_position_embeddings: 256 + max_position_embeddings: 1024 num_attention_heads: 4 num_hidden_layers: 2 num_key_value_heads: 4 @@ -67,7 +67,7 @@ model: rope_scaling: null tie_word_embeddings: true use_cache: true - vocab_size: 32000 + vocab_size: 50257 optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 @@ -88,11 +88,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b - tp: 2 + tp: 1 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER profiler: null @@ -105,6 +105,6 @@ tokens: limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 2 - sequence_length: 128 + sequence_length: 1024 train_steps: 200 val_check_interval: -1 diff --git a/examples/mamba/README.md b/examples/mamba/README.md index 5c31d07f..8eefa9c2 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -18,6 +18,18 @@ pip install -r requirements.txt > https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5 +## Bug related to nanotron +Encountered the following issue when ran train_mamba.sh: +``` +causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv +``` +Solved this by doing: +pip uninstall mamba-ssm +pip install causal_conv1d==1.1.1 +pip install mamba-ssm --no-cache-dir +https://github.com/state-spaces/mamba/issues/169 + + ## Credits Credits to the following repositories from which the code was adapted: - https://github.com/state-spaces/mamba diff --git a/pyproject.toml b/pyproject.toml index e65f37a5..6a0cfb83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ fast-modeling = [ nanosets = [ "transformers", - "datasets", + "datatrove[io,processing]@git+https://github.com/huggingface/datatrove", "numba", ] diff --git a/run_train.py b/run_train.py index b33231f4..021d955d 100644 --- a/run_train.py +++ b/run_train.py @@ -143,17 +143,17 @@ def get_dataloader_from_data_stage( elif isinstance(data.dataset, NanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 del tokenizer # Create Nanoset from nanotron.data.nanoset import Nanoset with main_rank_first(trainer.parallel_context.world_pg): train_dataset = Nanoset( - dataset_paths=data.dataset.dataset_path, + dataset_folders=data.dataset.dataset_folder, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, - token_dtype=token_dtype, + token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, random_seed=data.seed, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d5b9976f..de0fa3c0 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -93,25 +93,25 @@ def __post_init__(self): @dataclass class NanosetDatasetsArgs: - dataset_path: Union[str, dict, List[str]] + dataset_folder: Union[str, dict, List[str]] def __post_init__(self): - if isinstance(self.dataset_path, str): # Case 1: 1 Dataset file - self.dataset_path = [self.dataset_path] + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder + self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_path, List): # Case 2: > 1 Dataset file + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_path, dict): # Case 3: dict with > 1 dataset_path and weights - tmp_dataset_path = self.dataset_path.copy() - self.dataset_path = list(tmp_dataset_path.keys()) - self.dataset_weights = list(tmp_dataset_path.values()) + elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + tmp_dataset_folder = self.dataset_folder.copy() + self.dataset_folder = list(tmp_dataset_folder.keys()) + self.dataset_weights = list(tmp_dataset_folder.values()) @dataclass class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]] seed: Optional[int] num_loading_workers: Optional[int] = 1 @@ -145,6 +145,7 @@ class CheckpointsArgs: checkpoints_path: Path checkpoint_interval: int save_initial_state: Optional[bool] = False + save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[Path] = None checkpoints_path_is_shared_file_system: Optional[bool] = False diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 758d780f..1663f992 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -35,6 +35,8 @@ class ParallelismArgs: tp_linear_async_communication: Optional[bool] = None recompute_layer: bool = False + tp_recompute_allgather: bool = True + expert_parallel_size: int = 1 def __post_init__(self): diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py new file mode 100644 index 00000000..199527e1 --- /dev/null +++ b/src/nanotron/data/collator.py @@ -0,0 +1,80 @@ +import dataclasses +from typing import Dict, List, Union + +import numpy as np +import torch +from nanotron import distributed as dist +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer + + +@dataclasses.dataclass +class NanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # Make sure we load only what's necessary, ie we only load a `input_ids` column. + assert all(list(example.keys()) == ["input_ids"] for example in examples) + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? + input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 4719c476..9d3285f6 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,7 +1,7 @@ import nanotron.distributed as dist from nanotron import logging +from nanotron.data.collator import NanosetDataCollatorForCLM from nanotron.dataloader import ( - DataCollatorForCLM, EmptyInfiniteDataset, get_dataloader_worker_init, get_sampler, @@ -32,7 +32,7 @@ def build_nanoset_dataloader( # No need to spawn a lot of workers, we can just use main dataloader_num_workers = 0 - data_collator = DataCollatorForCLM( + data_collator = NanosetDataCollatorForCLM( sequence_length=sequence_length, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 9d62b33d..90200967 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -1,7 +1,10 @@ +import os +import warnings from typing import Dict, List, Tuple, Union import numpy as np import torch +from datatrove.utils.dataset import DatatroveFolderDataset from nanotron import logging from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.logging import log_rank @@ -15,49 +18,60 @@ class Nanoset(torch.utils.data.Dataset): The Nanoset dataset Args: - dataset_paths (List[str]): List of paths to tokenized datasets - dataset_weights (List[float]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + dataset_folders (List[str]): List of folders with tokenized datasets + dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ sequence_length (int): Sequence length of the built samples - token_dtype (Union[np.uint16, np.int32]): dtype of the tokens stored in the processed dataset files. np.uin16 for vocab sizes < 65535, np.int32 otherwise + token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size """ def __init__( self, - dataset_paths: List[str], - dataset_weights: Union[List[float], None], + dataset_folders: List[str], sequence_length: int, - token_dtype: Union[np.uint16, np.int32], + token_size: int, train_split_num_samples: int, + dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, ) -> None: + # Checks + if isinstance(dataset_folders, str): + warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]") + dataset_folders = [dataset_folders] + # Init - self.dataset_paths = dataset_paths - self.dataset_weights = dataset_weights + self.dataset_folders = dataset_folders self.sequence_length = sequence_length - self.token_dtype = token_dtype + self.token_size = token_size self.train_split_num_samples = train_split_num_samples self.random_seed = random_seed + self.datatrove_datasets = [] + for dataset_folder in self.dataset_folders: + self.datatrove_datasets.append( + DatatroveFolderDataset( + folder_path=dataset_folder, + filename_pattern=os.path.join(dataset_folder, "*.ds"), + seq_len=sequence_length, + recursive=False, + token_size=token_size, + shuffle=True, + ) + ) # Build Nanoset Index ## To build the index we need the length of each dataset - self.dataset_lengths = [] - for dataset_path in self.dataset_paths: - self.dataset_buffer_mmap = np.memmap(dataset_path, mode="r", order="C", dtype=self.token_dtype) - self.dataset_buffer = memoryview(self.dataset_buffer_mmap) - dataset_number_of_tokens = int(len(self.dataset_buffer)) - number_of_samples = int( - (dataset_number_of_tokens - 1) / sequence_length - ) # Discard last sample if length < sequence_length - self.dataset_lengths.append(number_of_samples) + self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] ## Set dataset weights if ( - self.dataset_weights is None + dataset_weights is None ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch self.dataset_weights = normalize(self.dataset_lengths) else: self.dataset_weights = normalize(dataset_weights) + assert len(dataset_folders) == len( + self.dataset_weights + ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() @@ -79,25 +93,12 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: idx (int): The index into the dataset Returns: - Dict[str, numpy.ndarray]: The input ids wrapped in a dictionary + Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary """ - dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - # Rebuild the memmap in every access to free memory - # https://stackoverflow.com/a/61472122 - self.dataset_buffer_mmap = np.memmap(self.dataset_paths[dataset], mode="r", order="C", dtype=self.token_dtype) - self.dataset_buffer = memoryview(self.dataset_buffer_mmap) - - # uint16 -> 2 bytes per token, int32 -> 4 bytes per token - offset = dataset_sample * self.sequence_length * (np.iinfo(self.token_dtype).bits / 8) - input_ids_tokens = np.frombuffer( - self.dataset_buffer, dtype=self.token_dtype, count=(self.sequence_length + 1), offset=int(offset) - ) - - # Return tokens as np.int32 as Torch can't handle uint16 - return {"input_ids": input_ids_tokens.astype(np.int32)} + return self.datatrove_datasets[dataset][dataset_sample] def build_nanoset_index(self) -> np.ndarray: """ @@ -124,15 +125,6 @@ def build_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index - def __del__(self) -> None: - """ - Clean up Nanoset - """ - - if hasattr(self, "dataset_buffer_mmap"): - self.dataset_buffer_mmap._mmap.close() - del self.dataset_buffer_mmap - def print_nanoset_info(self): log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0) @@ -141,10 +133,10 @@ def print_nanoset_info(self): ) # Print samples from each dataset + weight - dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_paths)) + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) for index, sample_count in enumerate(dataset_sample_count): log_rank( - f"> Total number of samples from the {self.dataset_paths[index].rsplit('/', 1)[-1]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 8c1f7bc3..5bdaaf4b 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -233,6 +233,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -242,8 +243,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) - # TODO @nouamane: why can't we torch.jit.script GLUActivation? - self.split_silu_mul = GLUActivation(config.hidden_act) + self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) @@ -422,6 +422,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. if config.rope_interleaved: @@ -912,6 +913,7 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 873d77df..bd41347a 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableReduceScatterSum.apply(grad_output, group), None + out = DifferentiableReduceScatterSum.apply(grad_output, group) + return out, None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): *rest_size, device=tensor.device, dtype=tensor.dtype, - requires_grad=tensor.requires_grad, + requires_grad=False, ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..e2ee3a29 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -20,13 +20,12 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( - differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -89,10 +88,10 @@ def forward( @staticmethod def backward(ctx, grad_output): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors - # All the inputs have softmax as thier gradient. + # All the inputs have softmax as their gradient. grad_input = softmax # For simplicity, work with the 2D gradient. sharded_hidden_size = softmax.size()[-1] @@ -121,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function): @staticmethod @assert_cuda_max_connections_set_to_1 - def forward(ctx, tensor, weight, bias, group, tp_mode): + def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): ctx.use_bias = bias is not None ctx.tp_mode = tp_mode ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.tensor_shape = tensor.size() if tp_mode is TensorParallelLinearMode.ALL_REDUCE: gathered_tensor = tensor @@ -141,7 +142,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - ctx.save_for_backward(tensor, weight) + # ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -149,14 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + if tp_recompute_allgather: + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) + else: + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -204,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # Wait communication handle.wait() + if tp_recompute_allgather: + ctx.save_for_backward(tensor, weight) + else: + ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -261,8 +271,8 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + handle1: Optional[dist.Work] = None + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -273,14 +283,10 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, + unsharded_tensor = MemoryBuffer().get( + "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype ) - handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -289,9 +295,6 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) - if handle is not None: - handle.wait() - # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -303,41 +306,128 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle: Optional[dist.Work] = None + handle2: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if handle1 is not None: + handle1.wait() + # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle is not None: - handle.wait() + if handle2 is not None: + handle2.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None + return grad_tensor, grad_weight, grad_bias, None, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): + """ + Column linear with memory_buffer for the allgather, context parallel + enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and + async communication disabled. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, + ): + + # Do allgather. + sharded_batch_size, *rest_size = input.shape + unsharded_batch_size = sharded_batch_size * group.size() + if group.size() == 1: + total_input = input.contiguous() + elif tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Prepare context. + ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + + # Get linear output. + out = F.linear(total_input, weight, bias) + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Either allgather the inputs again or get them from context. + group = ctx.group + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if group.size() == 1 or not tp_recompute_allgather: + total_input, weight, bias = ctx.saved_tensors + else: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.contiguous() + grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] + total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1] + grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) + total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim) + + # Compute gradients. + grad_weight = grad_output.T @ total_input + grad_input = grad_output @ weight + if group.size() == 1: + sub_grad_input = grad_input + else: + # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 + # We set grad_input to be contiguous in case it isn't already. + grad_input = grad_input.contiguous() + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None, None + + def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -345,18 +435,19 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - input = differentiable_all_gather(input, group=group) - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") - - return F.linear(input, weight, bias) + return F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( + input, weight, bias, group, tp_recompute_allgather + ) + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function): @@ -387,8 +478,7 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle_0: Optional[dist.Work] = None - handle_1: Optional[dist.Work] = None + handle: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -398,12 +488,8 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = torch.empty( - unsharded_batch_size, - *rest_size, - device=grad_output.device, - dtype=grad_output.dtype, - requires_grad=False, + total_grad_output = MemoryBuffer().get( + "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype ) # Doing gather + slicing during the NeMo forward pass can make this tensor @@ -412,31 +498,69 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - - grad_tensor = grad_output.matmul(weight) - - # wait for the first all_gather to finish before starting the second all_gather - if handle_0 is not None: - handle_0.wait() + handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - # TODO @thomasw21: gather along another dimension - sharded_batch_size, *rest_size = grad_tensor.shape + # total_grad_output: [b, s, h_out] + # weight: [h_out, h_in/n] + # total_grad_tensor: [b, s, h_in/n] + # grad_output: [b/n, s, h_out] + sharded_batch_size, *rest_size_grad_output = grad_output.shape + rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] if group.size() == 1: - total_grad_tensor = grad_tensor + total_grad_tensor = grad_output.matmul(weight) else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size, - device=grad_tensor.device, - dtype=grad_tensor.dtype, + *rest_size_grad_tensor, + device=grad_output.device, + dtype=grad_output.dtype, requires_grad=False, ) + before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( + total_grad_tensor, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + # compute local shard + torch.mm( + input=grad_output.view(-1, grad_output.shape[-1]), + mat2=weight, + out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), + ) - handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) + if handle is not None: + handle.wait() + + before_shard_grad_output, _, after_shard_grad_output = torch.split( + total_grad_output, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + + # before shard compute + if before_shard_grad_tensor.numel() > 0: + torch.mm( + input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), + mat2=weight, + out=before_shard_grad_tensor.view(-1, weight.shape[1]), + ) + # after shard compute + if after_shard_grad_tensor.numel() > 0: + torch.mm( + input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), + mat2=weight, + out=after_shard_grad_tensor.view(-1, weight.shape[1]), + ) # Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -454,9 +578,6 @@ def backward(ctx, grad_output): grad_weight = total_grad_output.t().matmul(tensor) grad_bias = total_grad_output.sum(dim=0) if use_bias else None - if handle_1 is not None: - handle_1.wait() - return total_grad_tensor, grad_weight, grad_bias, None, None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 40e89968..4c7325cd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,6 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + tp_recompute_allgather: bool = True, ): self.pg = pg self.world_size = pg.size() @@ -59,6 +60,7 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size + self.tp_recompute_allgather = tp_recompute_allgather super().__init__( in_features=self.in_features, @@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index e0159902..156a2a86 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -1,11 +1,31 @@ import functools +import operator import os +import torch from torch import nn from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.utils import Singleton + + +class MemoryBuffer(metaclass=Singleton): + """ + Global memory buffer to store intermediate activations that need not to be cached for the backward pass. + """ + + def __init__(self): + self.buffer = {} + + def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + required_numel = functools.reduce(operator.mul, shape, 1) + if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: + self.buffer[name, dtype] = torch.empty( + required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + return self.buffer[name, dtype][:required_numel].view(shape) def assert_cuda_max_connections_set_to_1(func): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index aceccc3a..76461481 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -442,7 +442,10 @@ def train( if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: self.save_checkpoint() dist.barrier() # let's wait for everyone before leaving - + + if self.config.checkpoints.save_final_state: + self.save_checkpoint() + self.post_training() def training_step( diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 4a2edc49..cd0656ec 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -14,6 +14,25 @@ from nanotron import distributed as dist +class Singleton(type): + """ + Singleton metaclass. + Create objects using this class as the metaclass to enable singleton behaviour. + For instance: + ``` + class Logger(metaclass=Singleton): + ... + ``` + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + class ContextManagers: """ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 33bb2480..72deb7f5 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -3,6 +3,7 @@ import json import os import sys +from argparse import Namespace from collections import OrderedDict from pathlib import Path @@ -10,8 +11,6 @@ package_path = Path(package.__file__).parent.parent.parent sys.path.append(str(package_path)) -from argparse import Namespace - import nanotron.distributed as dist import torch from nanotron.data.nanoset import Nanoset @@ -23,31 +22,34 @@ def create_dataset_paths(tmp_dir: str, quantity: int): - json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}") for i in range(quantity)] - mmap_dataset_path = [f"{path}_input_ids.npy" for path in json_dataset_path] + json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}.json") for i in range(quantity)] + datatrove_tokenized_dataset_paths = [os.path.join(tmp_dir, f"tokenized_documents_{i}") for i in range(quantity)] - return json_dataset_path, mmap_dataset_path + return json_dataset_path, datatrove_tokenized_dataset_paths def create_dummy_json_dataset(path_to_json: str, dummy_text: str, n_samples: int = 50000): - with open(path_to_json + ".json", "a") as json_file: + with open(path_to_json, "a") as json_file: for sample in range(n_samples): sample_dict = {"text": f"[{sample}] Hello! Im sample {sample}! And this is my dummy text: {dummy_text}"} json_file.write(json.dumps(sample_dict)) json_file.write("\n") -def preprocess_dummy_dataset(path_to_json: str, tokenizer: str): +def preprocess_dummy_dataset(json_dataset_path: str, datatrove_tokenized_dataset_path: str, tokenizer: str): # Create args for preprocessing args = Namespace( - input=path_to_json + ".json", + readers="jsonl", + dataset=json_dataset_path, column="text", - output_prefix=path_to_json, + glob_pattern=None, + output_folder=datatrove_tokenized_dataset_path, tokenizer_name_or_path=tokenizer, - add_special_tokens=False, + eos_token=None, + n_tasks=1, + logging_dir=None, ) - # tools/preprocess_data.py main main(args) @@ -122,7 +124,7 @@ def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: Par IDX_SAMPLE = 23 nanoset_identifiers = OrderedDict() - nanoset_identifiers["dataset_paths"] = nanoset.dataset_paths + nanoset_identifiers["dataset_folders"] = nanoset.dataset_folders nanoset_identifiers["dataset_weights"] = nanoset.dataset_weights.tolist() nanoset_identifiers["sequence_length"] = nanoset.sequence_length nanoset_identifiers["train_split_num_samples"] = nanoset.train_split_num_samples @@ -131,6 +133,7 @@ def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: Par nanoset_identifiers["input_ids"] = nanoset[IDX_SAMPLE]["input_ids"].tolist() nanoset_identifiers["dataset_index"] = nanoset.dataset_index.tolist() nanoset_identifiers["dataset_sample_index"] = nanoset.dataset_sample_index.tolist() + nanoset_identifiers["token_size"] = nanoset.token_size unique_description_hash = compute_hash(nanoset_identifiers) assert_tensor_synced_across_pg( diff --git a/tests/nanoset/test_build_nanoset_dataloader.py b/tests/nanoset/test_build_nanoset_dataloader.py index 331e4f64..5a48cb9c 100644 --- a/tests/nanoset/test_build_nanoset_dataloader.py +++ b/tests/nanoset/test_build_nanoset_dataloader.py @@ -1,6 +1,7 @@ import sys from math import isclose from pathlib import Path +from typing import List import numpy as np import pytest @@ -33,7 +34,7 @@ for all_4d_configs in get_all_4d_configurations(gpus) ], ) -@pytest.mark.parametrize("train_steps", [5, 100]) +@pytest.mark.parametrize("train_steps", [500, 10000]) @pytest.mark.parametrize("sequence_length", [512, 8192]) @pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"]) @rerun_if_address_is_in_use() @@ -42,16 +43,21 @@ def test_build_nanoset_dataloader( ): test_context = TestContext() - # Create dataset files - json_paths, mmap_dataset_paths = create_dataset_paths(tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2) + # Create dataset folders + json_paths, datatrove_tokenized_dataset_folders = create_dataset_paths( + tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2 + ) # Create dummy json datasets for idx, json_path in enumerate(json_paths): create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000) + # Preprocess json dataset with datatrove + for json_path, datatrove_tokenized_dataset_folder in zip(json_paths, datatrove_tokenized_dataset_folders): + preprocess_dummy_dataset(json_path, datatrove_tokenized_dataset_folder, tokenizer_name_or_path) + init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_build_nanoset_dataloader)( - json_paths=json_paths, - path_to_mmap_files=mmap_dataset_paths, + datatrove_tokenized_dataset_folders=datatrove_tokenized_dataset_folders, train_steps=train_steps, sequence_length=sequence_length, tokenizer_name_or_path=tokenizer_name_or_path, @@ -60,8 +66,7 @@ def test_build_nanoset_dataloader( def _test_build_nanoset_dataloader( parallel_context: ParallelContext, - json_paths: str, - path_to_mmap_files: str, + datatrove_tokenized_dataset_folders: List[str], train_steps: int, sequence_length: int, tokenizer_name_or_path: str, @@ -71,41 +76,37 @@ def _test_build_nanoset_dataloader( N_MICRO_BATCHES_PER_BATCH = 8 GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * N_MICRO_BATCHES_PER_BATCH * parallel_context.dp_pg.size() - # Preprocess dummy json datasets - for json_path in json_paths: - preprocess_dummy_dataset(path_to_json=json_path, tokenizer=tokenizer_name_or_path) - input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1) # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 del tokenizer # Create Nanoset configs: 1. Normal 2. Blended 3. Blended with weights nanoset_config = { - "dataset_paths": [path_to_mmap_files[0]], + "dataset_folders": [datatrove_tokenized_dataset_folders[0]], "dataset_weights": [1], "sequence_length": sequence_length, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": None, "sequence_length": sequence_length, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_weighted_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": [8, 2], "sequence_length": sequence_length, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } @@ -119,7 +120,7 @@ def _test_build_nanoset_dataloader( # Assert we have the same Nanoset in all ranks assert_nanoset_sync_across_all_ranks(train_dataset, parallel_context) - dataset_sample_count = count_dataset_indexes(train_dataset.dataset_index, len(train_dataset.dataset_paths)) + dataset_sample_count = count_dataset_indexes(train_dataset.dataset_index, len(train_dataset.dataset_folders)) for idx, ds_length in enumerate(train_dataset.dataset_lengths): # Assert Nanoset doesn't sample indexes greater than the datasets assert ( @@ -129,7 +130,7 @@ def _test_build_nanoset_dataloader( # Assert Nanoset builds up the correct blend WRT the dataset_weights assert isclose( normalize(dataset_sample_count).tolist()[idx], train_dataset.dataset_weights[idx], abs_tol=0.05 - ), f"Requested Nanoset to contain {round(train_dataset.dataset_weights[idx]*100, 2)}% of samples from {train_dataset.dataset_paths[idx]} but got {round(normalize(dataset_sample_count).tolist()[idx]*100, 2)}%" + ), f"Requested Nanoset to contain {round(train_dataset.dataset_weights[idx]*100, 2)}% of samples from {train_dataset.dataset_folders[idx]} but got {round(normalize(dataset_sample_count).tolist()[idx]*100, 2)}%" # Create Dataloaders dataloader = build_nanoset_dataloader( train_dataset, @@ -162,7 +163,7 @@ def _test_build_nanoset_dataloader( for all_4d_configs in get_all_4d_configurations(gpus) ], ) -@pytest.mark.parametrize("skipped_batches", [20, 50]) +@pytest.mark.parametrize("skipped_batches", [20, 5555]) @pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"]) @rerun_if_address_is_in_use() def test_recover_nanoset_dataloader( @@ -170,16 +171,21 @@ def test_recover_nanoset_dataloader( ): test_context = TestContext() - # Create dataset files - json_paths, mmap_dataset_paths = create_dataset_paths(tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2) + # Create dataset folders + json_paths, datatrove_tokenized_dataset_folders = create_dataset_paths( + tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2 + ) # Create dummy json datasets for idx, json_path in enumerate(json_paths): create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000) + # Preprocess json dataset with datatrove + for json_path, datatrove_tokenized_dataset_folder in zip(json_paths, datatrove_tokenized_dataset_folders): + preprocess_dummy_dataset(json_path, datatrove_tokenized_dataset_folder, tokenizer_name_or_path) + init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_recover_nanoset_dataloader)( - json_paths=json_paths, - path_to_mmap_files=mmap_dataset_paths, + datatrove_tokenized_dataset_folders=datatrove_tokenized_dataset_folders, skipped_batches=skipped_batches, tokenizer_name_or_path=tokenizer_name_or_path, ) @@ -187,8 +193,7 @@ def test_recover_nanoset_dataloader( def _test_recover_nanoset_dataloader( parallel_context: ParallelContext, - json_paths: str, - path_to_mmap_files: str, + datatrove_tokenized_dataset_folders: List[str], skipped_batches: int, tokenizer_name_or_path: str, ): @@ -197,43 +202,39 @@ def _test_recover_nanoset_dataloader( N_MICRO_BATCHES_PER_BATCH = 8 GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * N_MICRO_BATCHES_PER_BATCH * parallel_context.dp_pg.size() SEQUENCE_LENGTH = 1024 - TRAIN_STEPS = 100 - - # Preprocess dummy json datasets - for json_path in json_paths: - preprocess_dummy_dataset(path_to_json=json_path, tokenizer=tokenizer_name_or_path) + TRAIN_STEPS = 10000 input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1) # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 del tokenizer # Create Nanoset configs: 1. Normal 2. Blended 3. Blended with weights nanoset_config = { - "dataset_paths": [path_to_mmap_files[0]], + "dataset_folders": [datatrove_tokenized_dataset_folders[0]], "dataset_weights": [1], "sequence_length": SEQUENCE_LENGTH, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": None, "sequence_length": SEQUENCE_LENGTH, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_weighted_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": [8, 2], "sequence_length": SEQUENCE_LENGTH, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 261e6a4f..139546b7 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,19 +18,31 @@ @pytest.mark.parametrize("tp,dp,pp,sp", [pytest.param(i, 1, 1, 1) for i in range(3, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() def test_column_linear( - tp: int, dp: int, pp: int, sp: int, tp_mode: TensorParallelLinearMode, async_communication: bool + tp: int, + dp: int, + pp: int, + sp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, ): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather") init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_column_linear)( - tp_mode=tp_mode, async_communication=async_communication + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather ) def _test_column_linear( - parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -46,6 +58,7 @@ def _test_column_linear( mode=tp_mode, device="cuda", async_communication=async_communication, + tp_recompute_allgather=tp_recompute_allgather, ) # Un-sharded @@ -152,17 +165,41 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp,sp", [pytest.param(i, 1, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() +<<<<<<< HEAD def test_row_linear(tp: int, dp: int, pp: int, sp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +======= +def test_row_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): +>>>>>>> origin/main if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") +<<<<<<< HEAD init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_row_linear)( tp_mode=tp_mode, async_communication=async_communication +======= + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather +>>>>>>> origin/main ) -def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool): +def _test_row_linear( + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3 @@ -212,14 +249,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - + random_input.requires_grad = True # Row linear receives as input sharded input - random_sharded_input = random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] + random_sharded_input = ( + random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + .detach() + .clone() + ) + random_sharded_input.requires_grad = True # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -265,6 +307,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL else: assert row_linear.bias is None + torch.testing.assert_close( + random_sharded_input.grad, + random_input.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ], + ) + parallel_context.destroy() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 465d22f0..f3cdab70 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -1,26 +1,21 @@ -import argparse -import os -import shutil -import sys +""" +To process HuggingFace Datasets: + python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/emotion --n-tasks 16 hf --dataset dair-ai/emotion +To process Jsonl files: + python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/c4-es --n-tasks 16 jsonl --dataset raw_datasets/c4-es-json-files +""" -import numpy as np -import torch.distributed as dist -from tqdm import tqdm -from transformers import AutoTokenizer +import argparse -from datasets import concatenate_datasets, load_dataset +from datatrove.executor.local import LocalPipelineExecutor +from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader +from datatrove.pipeline.tokens import DocumentTokenizer def get_args(): parser = argparse.ArgumentParser() - group = parser.add_argument_group(title="input data") - group.add_argument( - "--input", type=str, required=True, help="Path to local stored dataset or repository on the Hugging Face hub" - ) - group.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset") - parser.add_argument("--split", type=str, default="train", help="Which split of the data to process") - group = parser.add_argument_group(title="tokenizer") + group = parser.add_argument_group(title="Tokenizer") group.add_argument( "--tokenizer-name-or-path", type=str, @@ -28,13 +23,54 @@ def get_args(): help="A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.", ) group.add_argument( - "--add-special-tokens", - action="store_true", - help="Whether or not to add special tokens when encoding the sequences. This will be passed to the Tokenizer", + "--eos-token", + type=str, + default=None, + help="EOS token to add after each document. Default: None", + ) + + group = parser.add_argument_group(title="Output data") + group.add_argument( + "--output-folder", type=str, required=True, help="Path to the output folder to store the tokenized documents" + ) + group = parser.add_argument_group(title="Miscellaneous configs") + group.add_argument( + "--logging-dir", + type=str, + default=None, + help="Path to a folder for storing the logs of the preprocessing step. Default: None", + ) + group.add_argument( + "--n-tasks", type=int, default=8, help="Total number of tasks to run the preprocessing step. Default: 8" + ) + # Subparsers for processing either Hugging Face datasets or jsonl files + sp = parser.add_subparsers( + dest="readers", + required=True, + description="Type of dataset to process. It can be either a Hugging Face Dataset loaded with datasets.load_data ('hf') or a .jsonl dataset ('jsonl')", + ) + + p1 = sp.add_parser(name="hf") + p1.add_argument( + "--dataset", + type=str, + required=True, + help="Path to local stored dataset or repository on the Hugging Face hub that can be loaded with datasets.load_dataset", ) + p1.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text") + p1.add_argument("--split", type=str, default="train", help="Which split of the data to process. Default: train") - group = parser.add_argument_group(title="output data") - group.add_argument("--output-prefix", type=str, required=True, help="Path to the output processed dataset file") + p2 = sp.add_parser(name="jsonl") + p2.add_argument( + "--dataset", + type=str, + required=True, + help="Path to a .jsonl file or a folder containing multiple .jsonl files", + ) + p2.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text") + p2.add_argument( + "--glob-pattern", type=str, default=None, help="A glob pattern to filter files to read. Default: None" + ) args = parser.parse_args() @@ -42,74 +78,33 @@ def get_args(): def main(args): - - world_size, rank = int(os.environ["WORLD_SIZE"]), int(os.environ["RANK"]) - - # Remove stdout from all processes except main to not flood the stdout - if rank: - sys.stdout = open(os.devnull, "w") - - # Check if output directory exists - if not os.path.isdir(os.path.abspath(os.path.join(args.output_prefix, os.path.pardir))): - print(f"Creating {os.path.abspath(os.path.join(args.output_prefix, os.path.pardir))} directory...") - os.makedirs(os.path.abspath(os.path.join(args.output_prefix, os.path.pardir)), exist_ok=True) - - if args.input.endswith(".json"): # For processing JSON files (Cross compatibility with other projects) - ds = load_dataset("json", data_files=args.input) - ds = concatenate_datasets( - [ds[splits] for splits in ds.keys()] - ) # load_dataset returns DatasetDict and we want a Dataset + # Build datatrove reader + if args.readers == "hf": + datatrove_reader = HuggingFaceDatasetReader( + dataset=args.dataset, + text_key=args.column, + dataset_options={"split": args.split}, + ) else: - ds = load_dataset(args.input, split=args.split) - - ds = ds.shard(num_shards=world_size, index=rank, contiguous=True) - ds = ds.select_columns(args.column) - - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 - - # Create tmp directory for worker outputs - tmp_folder = os.path.abspath(os.path.join(args.output_prefix, os.pardir, "tmp")) - os.makedirs(tmp_folder, exist_ok=True) - - print("Creating workers output files...") - worker_output_file = os.path.join(tmp_folder, f"worker_{rank}_input_ids.npy") - ds = ds.map( - lambda x: {"input_ids": tokenizer(x, add_special_tokens=args.add_special_tokens).input_ids}, - input_columns=args.column, - batched=True, - desc="Tokenizing Dataset", - remove_columns=[args.column], + datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) + + preprocess_executor = LocalPipelineExecutor( + pipeline=[ + datatrove_reader, + DocumentTokenizer( + output_folder=args.output_folder, + tokenizer_name_or_path=args.tokenizer_name_or_path, + eos_token=args.eos_token, + shuffle=False, + max_tokens_per_file=1e9, + ), + ], + tasks=args.n_tasks, + logging_dir=args.logging_dir, ) - - worker_input_ids_file = open(worker_output_file, "wb") - for sample in ds: - np_array = np.array(sample["input_ids"], dtype=token_dtype) - worker_input_ids_file.write(np_array.tobytes(order="C")) - worker_input_ids_file.close() - - # Wait for all workers to process each shard of the Dataset - dist.barrier() - - # Only the main rank merges the worker files - if not rank: - output_file = f"{args.output_prefix}_input_ids.npy" - input_ids_file = open(output_file, "wb") - for worker_idx in tqdm(range(world_size), desc="Merging workers output files"): - worker_output_file = os.path.join(tmp_folder, f"worker_{worker_idx}_input_ids.npy") - with open(worker_output_file, "rb") as f: - shutil.copyfileobj(f, input_ids_file) - os.remove(worker_output_file) - - input_ids_file.close() - os.rmdir(tmp_folder) - print(f"Done! {args.input} processed dataset stored in {output_file}") - - else: # Close devnull stdout redirect - sys.stdout.close() + preprocess_executor.run() if __name__ == "__main__": _args = get_args() - dist.init_process_group(backend="gloo") main(_args)