From dcb866f3ddbe2a2d27e5f8fa6e231f2ad1cee391 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 16 Sep 2024 17:53:50 +0200 Subject: [PATCH 01/33] feat: combined datasets implementation --- src/modalities/config/config.py | 4 +++ src/modalities/dataloader/dataset.py | 38 ++++++++++++++++++++ src/modalities/dataloader/dataset_factory.py | 16 ++++++++- src/modalities/registry/components.py | 2 ++ 4 files changed, 59 insertions(+), 1 deletion(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index c9d98a6f..ba1b8a43 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -281,6 +281,10 @@ class PackedMemMapDatasetMegatronConfig(BaseModel): sample_key: str +class CombinedDatasetConfig(BaseModel): + datasets: List[PydanticDatasetIFType] + + class MMapIndexedDatasetConfig(BaseModel): path: Path skip_warmup: bool diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index fa855e5e..f47d71b0 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -368,3 +368,41 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: curr_offset = segment_offset curr_len = segment_len return index + + +class CombinedDataset(Dataset): + """Combines multiple datasets into one large dataset at runtime.""" + + def __init__(self, datasets: List[Dataset]): + """Initializes the CombinedDataset object, combining multiple datasets. + + Args: + datasets (List[Dataset]): A list of datasets to combine. + """ + self.datasets = datasets + self.cumulative_sizes = CombinedDataset._get_cummulated_sizes(datasets=datasets) + + @staticmethod + def _get_cummulated_sizes(datasets: List[Dataset]) -> List[int]: + total = 0 + cummulated_sizes = [0] + for dataset in datasets: + total += len(dataset) + cummulated_sizes.append(total) + return cummulated_sizes + + def _find_dataset_idx(self, idx: int) -> int: + for i, cumulative_size in enumerate(self.cumulative_sizes): + if idx < cumulative_size: + return i + raise IndexError(f"Index {idx} is out of bounds.") + + def __len__(self) -> int: + return self.cumulative_sizes[-1] + + def __getitem__(self, idx: int) -> Dict: + dataset_idx = self._find_dataset_idx(idx) + local_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + sample = self.datasets[dataset_idx][local_idx] + return sample diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 0483a4fd..a2a996d0 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -1,9 +1,11 @@ from pathlib import Path -from typing import Optional, Tuple +from typing import List, Optional, Tuple from transformers import PreTrainedTokenizer from modalities.dataloader.dataset import ( + CombinedDataset, + Dataset, DummyDataset, DummySampleConfig, MemMapDataset, @@ -94,3 +96,15 @@ def get_packed_mem_map_dataset_megatron( raw_data_path=raw_data_path, block_size=sequence_length + 1, sample_key=sample_key ) return dataset + + @staticmethod + def get_combined_dataset(datasets: List[Dataset]) -> Dataset: + """Factory method for creating a combined datset . + + Args: + datasets (List[Dataset]): List of datasets to combine. + + Returns: + Dataset: CombinedDataset object. 
+ """ + return CombinedDataset(datasets=datasets) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 9585dfac..6eaaa0c1 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -22,6 +22,7 @@ CheckpointedOptimizerConfig, CheckpointSavingConfig, CLMCrossEntropyLossConfig, + CombinedDatasetConfig, ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, DistributedSamplerConfig, @@ -165,6 +166,7 @@ class ComponentEntity: PackedMemMapDatasetMegatronConfig, ), ComponentEntity("dataset", "dummy_dataset", DatasetFactory.get_dummy_dataset, DummyDatasetConfig), + ComponentEntity("dataset", "combined_dataset", DatasetFactory.get_combined_dataset, CombinedDatasetConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers From 63b370e5ee0e04674842d6559767e073bc49ed04 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 17 Sep 2024 10:30:56 +0200 Subject: [PATCH 02/33] feat: added DistributedSampler (unmodified pytorch version) --- src/modalities/dataloader/samplers.py | 144 +++++++++++++++++++++++++- 1 file changed, 142 insertions(+), 2 deletions(-) diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index 070b7f71..0cb5bbc4 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -1,6 +1,9 @@ -from typing import Optional +import math +from typing import Iterator, Optional, TypeVar -from torch.utils.data import BatchSampler, Sampler +import torch +import torch.distributed as dist +from torch.utils.data import BatchSampler, Dataset, Sampler class ResumableBatchSampler(Sampler): @@ -58,3 +61,140 @@ def batch_size(self) -> int: int: The batch size of the underlying batch sampler. """ return self.underlying_batch_sampler.batch_size + + +T_co = TypeVar("T_co", covariant=True) + + +class ResumableDistributedSampler(Sampler[T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + We adopted this class from pytorch's DistributedSampler class and added the ability to resume from a specific index. + source: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. 
If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
+ """ + self.epoch = epoch From 51f3fefd362f3ce6790493e284e7e6c7b2d95f33 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 18 Sep 2024 10:32:48 +0200 Subject: [PATCH 03/33] refactor: DistributedSampler from pytorch --- docs/components/components.md | 1 + src/modalities/dataloader/samplers.py | 55 ++++++++++++++------------- src/modalities/registry/components.py | 2 +- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/docs/components/components.md b/docs/components/components.md index 4e41110e..a7079628 100644 --- a/docs/components/components.md +++ b/docs/components/components.md @@ -56,6 +56,7 @@ | dataset | mem_map_dataset | [DatasetFactory.get_mem_map_dataset](../../src/modalities/dataloader/dataset_factory.py)| [MemMapDatasetConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | MemMap Dataset | | dataset | packed_mem_map_dataset_continuous | [DatasetFactory.get_packed_mem_map_dataset_continuous](../../src/modalities/dataloader/dataset_factory.py)| [PackedMemMapDatasetContinuousConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Packed Memory Mapped Dataset Continuous | | dataset | dummy_dataset | [DatasetFactory.get_dummy_dataset](../../src/modalities/dataloader/dataset_factory.py)| [DummyDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dummy dataset creating random samples of specified shape | +| dataset | combined | [DatasetFactory.get_combined_dataset](../../src/modalities/dataloader/dataset_factory.py)| [CombinedDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dataset implementation combining multiple datasets into one. | ## Data sampling diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index 0cb5bbc4..8eb35e8c 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -120,39 +120,38 @@ class ResumableDistributedSampler(Sampler[T_co]): def __init__( self, dataset: Dataset, + rank: int, num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - seed: int = 0, - drop_last: bool = False, + epoch: Optional[int] = 0, + shuffle: Optional[bool] = False, + seed: Optional[int] = 0, + drop_last: Optional[bool] = False, + skip_num_global_samples: Optional[int] = 0, ) -> None: - if num_replicas is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - num_replicas = dist.get_world_size() - if rank is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - rank = dist.get_rank() - if rank >= num_replicas or rank < 0: - raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") + num_replicas = dist.get_world_size() + self.rank = rank self.dataset = dataset self.num_replicas = num_replicas - self.rank = rank - self.epoch = 0 + self.epoch = epoch self.drop_last = drop_last + self.skip_num_global_samples = skip_num_global_samples + + self.global_num_samples = len(self.dataset) - self.skip_num_global_samples # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. 
- if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + if self.drop_last and self.global_num_samples % self.num_replicas != 0: # type: ignore[arg-type] # Split to nearest available length that is evenly divisible. # This is to ensure each rank receives the same amount of data when # using this Sampler. - self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + self.local_num_samples = math.ceil( + (self.global_num_samples - self.num_replicas) / self.num_replicas # type: ignore[arg-type] ) else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] - self.total_size = self.num_samples * self.num_replicas + # if this is not integer divisible, we will add padding by reusing the beginning of the data + self.local_num_samples = math.ceil(self.global_num_samples / self.num_replicas) # type: ignore[arg-type] + + # the actual number of samples we will be iterating over + self.global_num_samples_effective = self.local_num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -165,26 +164,28 @@ def __iter__(self) -> Iterator[T_co]: else: indices = list(range(len(self.dataset))) # type: ignore[arg-type] + indices = indices[self.skip_num_global_samples :] + if not self.drop_last: # add extra samples to make it evenly divisible - padding_size = self.total_size - len(indices) + padding_size = self.global_num_samples_effective - len(indices) if padding_size <= len(indices): indices += indices[:padding_size] else: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. - indices = indices[: self.total_size] - assert len(indices) == self.total_size + indices = indices[: self.global_num_samples_effective] + assert len(indices) == self.global_num_samples_effective # subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples + indices = indices[self.rank : self.global_num_samples_effective : self.num_replicas] + assert len(indices) == self.local_num_samples return iter(indices) def __len__(self) -> int: - return self.num_samples + return self.local_num_samples def set_epoch(self, epoch: int) -> None: r""" diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index f088eb93..959314c9 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -176,7 +176,7 @@ class ComponentEntity: PackedMemMapDatasetMegatronConfig, ), ComponentEntity("dataset", "dummy_dataset", DatasetFactory.get_dummy_dataset, DummyDatasetConfig), - ComponentEntity("dataset", "combined_dataset", DatasetFactory.get_combined_dataset, CombinedDatasetConfig), + ComponentEntity("dataset", "combined", DatasetFactory.get_combined_dataset, CombinedDatasetConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers From a9a5d935f618e33dbe7ae3dd5fe1d361d9c64319 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 18 Sep 2024 10:33:19 +0200 Subject: [PATCH 04/33] refactor: vectorized packed index generation --- src/modalities/dataloader/dataset.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index f47d71b0..2baf834d 100644 --- a/src/modalities/dataloader/dataset.py +++ 
b/src/modalities/dataloader/dataset.py @@ -328,10 +328,21 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: # of the subsequent sample). num_samples = (total_tokens - self.block_size) // (self.block_size - 1) + 1 # given num_samples we calculate the starting index and length of each sample as tuple. - return [ - ((i * self.block_size - i) * self._token_size_in_bytes, self.block_size * self._token_size_in_bytes) - for i in range(num_samples) - ] + # return [ + # ((i * self.block_size - i) * self._token_size_in_bytes, self.block_size * self._token_size_in_bytes) + # for i in range(num_samples) + # ] + + # Create an array of indices (i values) + i_array = np.arange(num_samples) + + # Vectorized operations + first_component = (i_array * self.block_size - i_array) * self._token_size_in_bytes + second_component = np.full(num_samples, self.block_size * self._token_size_in_bytes) + + # Combine both components into a 2D array of tuples (or list of tuples if needed) + result = np.stack((first_component, second_component), axis=1) + return result class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase): @@ -394,7 +405,7 @@ def _get_cummulated_sizes(datasets: List[Dataset]) -> List[int]: def _find_dataset_idx(self, idx: int) -> int: for i, cumulative_size in enumerate(self.cumulative_sizes): if idx < cumulative_size: - return i + return i - 1 raise IndexError(f"Index {idx} is out of bounds.") def __len__(self) -> int: @@ -402,7 +413,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Dict: dataset_idx = self._find_dataset_idx(idx) - local_idx = idx - self.cumulative_sizes[dataset_idx - 1] + local_idx = idx - self.cumulative_sizes[dataset_idx] sample = self.datasets[dataset_idx][local_idx] return sample From 3198e4d63846a54521fd8e7504d844a4f406aa19 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 24 Sep 2024 11:50:06 +0200 Subject: [PATCH 05/33] feat: added test coverage for CombinedDataset --- src/modalities/dataloader/dataset.py | 7 +++++- tests/dataloader/test_combined_dataset.py | 27 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/dataloader/test_combined_dataset.py diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 2baf834d..85704e90 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -382,7 +382,12 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: class CombinedDataset(Dataset): - """Combines multiple datasets into one large dataset at runtime.""" + """Combines multiple datasets into one large dataset at runtime. + + Note: When using this class to combine multiple `PackedMemMapDatasetes`, then each packed sample + is packed from a single dataset (i.e., the samples are not mixed between datasets). + In the Dataloader a batch will still contain packed samples from different datasets. + """ def __init__(self, datasets: List[Dataset]): """Initializes the CombinedDataset object, combining multiple datasets. 
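
With the off-by-one fix to _find_dataset_idx in place, the combined index lookup behaves as follows; a minimal sketch, assuming plain Python lists stand in for the dataset objects (the new test below exercises the same boundary cases):

    from modalities.dataloader.dataset import CombinedDataset

    dataset_a = list(range(10))       # 10 samples: 0..9
    dataset_b = list(range(10, 15))   # 5 samples: 10..14
    combined = CombinedDataset(datasets=[dataset_a, dataset_b])

    # cumulative_sizes == [0, 10, 15], so the combined length is 15
    assert len(combined) == 15
    # idx 9 is below cumulative size 10 -> dataset 0, local_idx = 9 - 0 = 9
    assert combined[9] == 9
    # idx 10 is below cumulative size 15 -> dataset 1, local_idx = 10 - 10 = 0
    assert combined[10] == 10
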
diff --git a/tests/dataloader/test_combined_dataset.py b/tests/dataloader/test_combined_dataset.py new file mode 100644 index 00000000..4aee5e6c --- /dev/null +++ b/tests/dataloader/test_combined_dataset.py @@ -0,0 +1,27 @@ +import pytest + +from modalities.dataloader.dataset import CombinedDataset + + +@pytest.fixture +def dummy_dataset_1() -> list[int]: + return list(range(10)) + + +@pytest.fixture +def dummy_dataset_2() -> list[int]: + return list(range(10, 15)) + + +def test_combined_dataset(dummy_dataset_1: list[int], dummy_dataset_2: list[int]): + combined_dataset = CombinedDataset(datasets=[dummy_dataset_1, dummy_dataset_2]) + + # check that length is calculated correctly + assert len(combined_dataset) == 15 + + # check that the elements are iterated over in order + assert [i for i in combined_dataset] == list(range(15)) + + # check that we throw an error when trying to access an index that is out of bounds + with pytest.raises(IndexError): + combined_dataset[15] From 299d2e6b29eb5d433c2ffb352c8000dcc7674f00 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 24 Sep 2024 15:42:47 +0200 Subject: [PATCH 06/33] refactor: moved sampler tests --- tests/dataloader/samplers/__init__.py | 0 tests/dataloader/{ => samplers}/test_samplers.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/dataloader/samplers/__init__.py rename tests/dataloader/{ => samplers}/test_samplers.py (100%) diff --git a/tests/dataloader/samplers/__init__.py b/tests/dataloader/samplers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dataloader/test_samplers.py b/tests/dataloader/samplers/test_samplers.py similarity index 100% rename from tests/dataloader/test_samplers.py rename to tests/dataloader/samplers/test_samplers.py From 915ec8fdbe81114d16f0fe124cc598587fcb13ac Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 24 Sep 2024 15:43:35 +0200 Subject: [PATCH 07/33] feat: implemented distributed sampler tests --- .../samplers/test_distributed_samplers.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/dataloader/samplers/test_distributed_samplers.py diff --git a/tests/dataloader/samplers/test_distributed_samplers.py b/tests/dataloader/samplers/test_distributed_samplers.py new file mode 100644 index 00000000..0ca36c51 --- /dev/null +++ b/tests/dataloader/samplers/test_distributed_samplers.py @@ -0,0 +1,88 @@ +import math + +import pytest + +from modalities.dataloader.samplers import ResumableDistributedSampler + + +@pytest.mark.parametrize( + "num_samples, epoch, shuffle, seed, drop_last, skip_num_global_samples", + [ + (30, 0, False, 0, False, 0), + (30, 0, False, 0, True, 0), # drop_last has no effect because integer divisible + (30, 0, False, 0, False, 9), + (30, 0, False, 0, True, 9), # drop_last has no effect because integer divisible + (30, 0, False, 0, True, 10), # drop_last has an effect because not integer divisible + (30, 0, False, 0, False, 10), # we have to reuse the initial samples (1 sample) + ], +) +def test_dropping_and_reusing( + num_samples: int, epoch: int, shuffle: bool, seed: int, drop_last: bool, skip_num_global_samples: int +): + dataset = list(range(num_samples)) + num_replicas = 3 # world size + samplers = [ + ResumableDistributedSampler( + dataset=dataset, + rank=rank, + num_replicas=num_replicas, + epoch=epoch, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + skip_num_global_samples=skip_num_global_samples, + ) + for rank in range(num_replicas) + ] + + samples = [[dataset[i] for i 
in sampler] for sampler in samplers] + + if drop_last: + # when drop_last true, we drop the last samples so that every data parallel rank + # has the same number of samples. + # Note that also means that the last, remaining samples (i.e., maximum num_ranks -1) + # are not used at all + cut_off_samples = len(dataset) - (len(dataset) - skip_num_global_samples) % num_replicas + padded_samples = [] + else: + cut_off_samples = len(dataset) + samples_left = len(dataset) - skip_num_global_samples + padding_size = math.ceil(samples_left / num_replicas) * num_replicas - samples_left + # when drop_last false, we reuse the last samples (i.e., maximum num_ranks -1) + # so that every data parallel ran, has a full last batch + padded_samples = dataset[:padding_size] + + assert dataset[skip_num_global_samples:cut_off_samples] + padded_samples == list( + s for t in zip(*samples) for s in t + ) + + +@pytest.mark.parametrize( + "num_samples, epoch, shuffle, seed, drop_last, skip_num_global_samples", + [ + (30, 0, True, 0, True, 0), + ], +) +def test_shuffling( + num_samples: int, epoch: int, shuffle: bool, seed: int, drop_last: bool, skip_num_global_samples: int +): + dataset = list(range(num_samples)) + num_replicas = 3 # world size + samplers = [ + ResumableDistributedSampler( + dataset=dataset, + rank=rank, + num_replicas=num_replicas, + epoch=epoch, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + skip_num_global_samples=skip_num_global_samples, + ) + for rank in range(num_replicas) + ] + + samples = [[dataset[i] for i in sampler] for sampler in samplers] + samples_flat = [s for t in zip(*samples) for s in t] + + assert set(samples_flat) == set(dataset) From a25a788d54cde1b09997812ffda27640091fcd14 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 24 Sep 2024 15:44:31 +0200 Subject: [PATCH 08/33] refactor: refactored ResumableDistributedSampler --- src/modalities/config/config.py | 13 ++++- src/modalities/dataloader/samplers.py | 80 +++++++++++---------------- src/modalities/registry/components.py | 5 ++ 3 files changed, 47 insertions(+), 51 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index ea30d519..8672f391 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -265,6 +265,17 @@ class DistributedSamplerConfig(BaseModel): drop_last: Literal[True] = True +class ResumableDistributedSamplerConfig(BaseModel): + dataset: PydanticDatasetIFType + rank: Annotated[int, Field(strict=True, ge=0)] + num_replicas: Annotated[int, Field(strict=True, ge=0)] = None + epoch: Annotated[int, Field(strict=True, ge=0)] = 0 + shuffle: Optional[bool] = False + seed: Optional[int] = 0 + drop_last: Literal[True] = True + skip_num_global_samples: Annotated[int, Field(strict=True, ge=0)] = 0 + + class MemMapDatasetConfig(BaseModel): raw_data_path: FilePath index_path: Optional[FilePath] = None @@ -317,8 +328,6 @@ class LLMDataLoaderConfig(BaseModel): collate_fn: Optional[PydanticCollateFnIFType] = None num_workers: Annotated[int, Field(strict=True, ge=0)] pin_memory: bool - skip_num_batches: Optional[int] = 0 - fixed_num_batches: Optional[int] = None class RepeatingDataLoaderConfig(BaseModel): diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index 8eb35e8c..6c6c1321 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -67,7 +67,7 @@ def batch_size(self) -> int: class ResumableDistributedSampler(Sampler[T_co]): - r"""Sampler that restricts data loading 
to a subset of the dataset. + """Sampler that restricts data loading to a subset of the dataset. We adopted this class from pytorch's DistributedSampler class and added the ability to resume from a specific index. source: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py @@ -80,41 +80,6 @@ class ResumableDistributedSampler(Sampler[T_co]): .. note:: Dataset is assumed to be of constant size and that any instance of it always returns the same elements in the same order. - - Args: - dataset: Dataset used for sampling. - num_replicas (int, optional): Number of processes participating in - distributed training. By default, :attr:`world_size` is retrieved from the - current distributed group. - rank (int, optional): Rank of the current process within :attr:`num_replicas`. - By default, :attr:`rank` is retrieved from the current distributed - group. - shuffle (bool, optional): If ``True`` (default), sampler will shuffle the - indices. - seed (int, optional): random seed used to shuffle the sampler if - :attr:`shuffle=True`. This number should be identical across all - processes in the distributed group. Default: ``0``. - drop_last (bool, optional): if ``True``, then the sampler will drop the - tail of the data to make it evenly divisible across the number of - replicas. If ``False``, the sampler will add extra indices to make - the data evenly divisible across the replicas. Default: ``False``. - - .. warning:: - In distributed mode, calling the :meth:`set_epoch` method at - the beginning of each epoch **before** creating the :class:`DataLoader` iterator - is necessary to make shuffling work properly across multiple epochs. Otherwise, - the same ordering will be always used. - - Example:: - - >>> # xdoctest: +SKIP - >>> sampler = DistributedSampler(dataset) if is_distributed else None - >>> loader = DataLoader(dataset, shuffle=(sampler is None), - ... sampler=sampler) - >>> for epoch in range(start_epoch, n_epochs): - ... if is_distributed: - ... sampler.set_epoch(epoch) - ... 
train(loader) """ def __init__( @@ -128,7 +93,11 @@ def __init__( drop_last: Optional[bool] = False, skip_num_global_samples: Optional[int] = 0, ) -> None: - num_replicas = dist.get_world_size() + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + self.rank = rank self.dataset = dataset self.num_replicas = num_replicas @@ -160,29 +129,42 @@ def __iter__(self) -> Iterator[T_co]: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + indices_full = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] + indices_full = list(range(len(self.dataset))) # type: ignore[arg-type] - indices = indices[self.skip_num_global_samples :] + indices_without_skipped = indices_full[self.skip_num_global_samples :] if not self.drop_last: # add extra samples to make it evenly divisible - padding_size = self.global_num_samples_effective - len(indices) - if padding_size <= len(indices): - indices += indices[:padding_size] + padding_size = self.global_num_samples_effective - len(indices_without_skipped) + if padding_size <= len(indices_full): + indices_without_skipped += indices_full[:padding_size] # we want to reuse the beginning of the data else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + # if the padding size is larger than the data, we create an extended index by repeating the indices + indices_without_skipped += (indices_full * math.ceil(padding_size / len(indices_full)))[:padding_size] else: # remove tail of data to make it evenly divisible. 
- indices = indices[: self.global_num_samples_effective] - assert len(indices) == self.global_num_samples_effective + indices_without_skipped = indices_without_skipped[: self.global_num_samples_effective] + + if len(indices_without_skipped) != self.global_num_samples_effective: + raise ValueError( + f"global_num_samples_effective ({self.global_num_samples_effective}) does not match the actual" + f"number of samples ({len(indices_without_skipped)})" + ) # subsample - indices = indices[self.rank : self.global_num_samples_effective : self.num_replicas] - assert len(indices) == self.local_num_samples + indices_without_skipped = indices_without_skipped[ + self.rank : self.global_num_samples_effective : self.num_replicas + ] + + if len(indices_without_skipped) != self.local_num_samples: + raise ValueError( + f"local_num_samples ({self.local_num_samples}) does not match the actual" + f"number of samples ({len(indices_without_skipped)})" + ) - return iter(indices) + return iter(indices_without_skipped) def __len__(self) -> int: return self.local_num_samples diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 959314c9..89079f04 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -42,6 +42,7 @@ PreTrainedHFTokenizerConfig, PreTrainedSPTokenizerConfig, RepeatingDataLoaderConfig, + ResumableDistributedSamplerConfig, RichProgressSubscriberConfig, RichResultSubscriberConfig, SaveEveryKStepsCheckpointingStrategyConfig, @@ -54,6 +55,7 @@ from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.dataset import DummyDatasetConfig from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.dataloader.samplers import ResumableDistributedSampler from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, ResultsSubscriberFactory, @@ -179,6 +181,9 @@ class ComponentEntity: ComponentEntity("dataset", "combined", DatasetFactory.get_combined_dataset, CombinedDatasetConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), + ComponentEntity( + "sampler", "resumable_distributed_sampler", ResumableDistributedSampler, ResumableDistributedSamplerConfig + ), # batch samplers ComponentEntity("batch_sampler", "default", BatchSampler, BatchSamplerConfig), # collators From dd5316e2e6019640dd3314f3ef4820da6f1aada9 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 24 Sep 2024 15:46:00 +0200 Subject: [PATCH 09/33] refactor: commented out old sample skipping in dataloader --- .../dataloader/dataloader_factory.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index e327f11e..1f12ca47 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -1,11 +1,9 @@ -from typing import Callable, Optional +from typing import Callable from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader -from modalities.dataloader.samplers import ResumableBatchSampler -from modalities.exceptions import ConfigError class DataloaderFactory: @@ -17,8 +15,6 @@ def get_dataloader( collate_fn: Callable, num_workers: int, pin_memory: bool, - skip_num_batches: Optional[int] = 0, - fixed_num_batches: 
Optional[int] = None, ) -> LLMDataLoader: """ Factory method for the instantiation of LLMDataLoader. @@ -44,19 +40,20 @@ def get_dataloader( LLMDataLoader: Instance of LLMDataLoader """ - batch_sampler = ResumableBatchSampler( - start_index=skip_num_batches, underlying_batch_sampler=batch_sampler, max_num_elements=fixed_num_batches - ) + # batch_sampler = ResumableBatchSampler( + # start_index=skip_num_batches, underlying_batch_sampler=batch_sampler, max_num_elements=fixed_num_batches + # ) - if fixed_num_batches is not None and fixed_num_batches <= skip_num_batches: - raise ConfigError("fixed_num_batches must be larger than skip_num_batches") + # if fixed_num_batches is not None and fixed_num_batches <= skip_num_batches: + # raise ConfigError("fixed_num_batches must be larger than skip_num_batches") - # make sure that the batch sampler has enough elements such that we can fix the number of batches to num_batches - if fixed_num_batches is not None and len(batch_sampler) < fixed_num_batches - skip_num_batches: - raise ConfigError( - f"The dataloader contains only {len(batch_sampler)} batches, which is less than " - f"specified fixed amount of batches of {fixed_num_batches}." - ) + # # make sure that the batch sampler has enough elements such + # # that we can fix the number of batches to num_batches + # if fixed_num_batches is not None and len(batch_sampler) < fixed_num_batches - skip_num_batches: + # raise ConfigError( + # f"The dataloader contains only {len(batch_sampler)} batches, which is less than " + # f"specified fixed amount of batches of {fixed_num_batches}." + # ) dataloader = LLMDataLoader( dataloader_tag=dataloader_tag, From c7a9cfc1d309d4782dfa78cd31464b22f59fa215 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Tue, 24 Sep 2024 15:46:18 +0200 Subject: [PATCH 10/33] feat: added new sampling strategy to config lorem ipsum --- config_files/training/config_lorem_ipsum.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index 6f4ae630..ffa43f2d 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -72,7 +72,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -84,16 +83,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: 0 collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE From 16f684e0eca08d4f3e776332a013456450c0967b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 11:47:30 +0200 Subject: [PATCH 11/33] feat: added more tests for the distributed sampler --- .../samplers/test_distributed_samplers.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/dataloader/samplers/test_distributed_samplers.py b/tests/dataloader/samplers/test_distributed_samplers.py index 0ca36c51..dc7ab092 100644 --- a/tests/dataloader/samplers/test_distributed_samplers.py +++ 
b/tests/dataloader/samplers/test_distributed_samplers.py @@ -19,6 +19,7 @@ def test_dropping_and_reusing( num_samples: int, epoch: int, shuffle: bool, seed: int, drop_last: bool, skip_num_global_samples: int ): + # we test that drop_last and or reusing the initial samples works as expected dataset = list(range(num_samples)) num_replicas = 3 # world size samplers = [ @@ -66,6 +67,8 @@ def test_dropping_and_reusing( def test_shuffling( num_samples: int, epoch: int, shuffle: bool, seed: int, drop_last: bool, skip_num_global_samples: int ): + # we test that shuffling leads to a different order of the samples and all samples of the + # original dataset are used dataset = list(range(num_samples)) num_replicas = 3 # world size samplers = [ @@ -86,3 +89,62 @@ def test_shuffling( samples_flat = [s for t in zip(*samples) for s in t] assert set(samples_flat) == set(dataset) + assert samples_flat != dataset + + +@pytest.mark.parametrize( + "num_samples, epoch, shuffle, seed, drop_last, skip_num_global_samples", + [ + (30, 0, False, 0, True, 0), + (30, 0, True, 0, True, 0), + ], +) +def test_ordering_with_different_world_sizes_and_shuffling( + num_samples: int, epoch: int, shuffle: bool, seed: int, drop_last: bool, skip_num_global_samples: int +): + # 1) we test that WITHOUT shuffling the order of samples is the same as in the original dataset + # for different world sizes. + # 2) we test that WITH shuffling the order of samples is the same for different world sizes + # but not the same order as in the original dataset. + dataset = list(range(num_samples)) + samplers_3 = [ + ResumableDistributedSampler( + dataset=dataset, + rank=rank, + num_replicas=3, + epoch=epoch, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + skip_num_global_samples=skip_num_global_samples, + ) + for rank in range(3) + ] + + samplers_6 = [ + ResumableDistributedSampler( + dataset=dataset, + rank=rank, + num_replicas=6, + epoch=epoch, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + skip_num_global_samples=skip_num_global_samples, + ) + for rank in range(6) + ] + + samples_3 = [[dataset[i] for i in sampler] for sampler in samplers_3] + samples_flat_3 = [s for t in zip(*samples_3) for s in t] + + samples_6 = [[dataset[i] for i in sampler] for sampler in samplers_6] + samples_flat_6 = [s for t in zip(*samples_6) for s in t] + + if not shuffle: + assert dataset == samples_flat_3 + assert dataset == samples_flat_6 + else: + assert samples_flat_3 == samples_flat_6 + assert set(samples_flat_3) == set(dataset) + assert samples_flat_6 != dataset From d21e0df30c33e76ed6156b00e80d0bef320acd93 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 11:55:11 +0200 Subject: [PATCH 12/33] chore: added documentation to ResumableDistributedSampler --- src/modalities/dataloader/samplers.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index 6c6c1321..cb963a4c 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -93,6 +93,26 @@ def __init__( drop_last: Optional[bool] = False, skip_num_global_samples: Optional[int] = 0, ) -> None: + """Instantiates a distributed and resumable Sampler object. + + Args: + dataset (Dataset): The dataset to sample from. + rank (int): The global rank of the current process. + num_replicas (int, optional): Number of replicas. + This usually equals the world size. Defaults to None. + epoch (int, optional): Current epoch. Defaults to 0. 
+ shuffle (bool, optional): Boolean flag whether to shuffle the data. Defaults to False. + seed (int, optional): Seed for the shuffling. Defaults to 0. + drop_last (bool, optional): Boolean flag indicating whether to drop the last samples + that cannot be distributed over all ranks (i.e., maximum world size - samples). + If drop_last is false padding is applied for these samples, by resampling the initial samples. + Defaults to False. + skip_num_global_samples (int, optional): Number of samples to skip, e.g., due to warmstart. + Defaults to 0. + + Raises: + RuntimeError: Requires distributed package to be available if num_replicas is None. + """ if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") From fb9deea9dce107a6fa7c3a31b3f799c3289fc99b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 16:17:45 +0200 Subject: [PATCH 13/33] refactor: the PackedMemMapDatasetContinuous does not load the index by default anymore. --- .../dataloader/create_packed_data.py | 27 ++++++--- src/modalities/dataloader/dataset.py | 60 +++++++++++-------- 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index 33775fcd..fda2fd63 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -323,7 +323,7 @@ class EmbeddedStreamData: TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4 HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES - def __init__(self, data_path: Path): + def __init__(self, data_path: Path, load_index: Optional[bool] = True): """ Initializes an EmbeddedStreamData object. @@ -352,14 +352,27 @@ def __init__(self, data_path: Path): self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False) # get index - f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) - pkl_encoded_index = f.read() - # contains the start offset and length of each segment - # as byte positions in the data section - self.index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index) + if load_index: + f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) + pkl_encoded_index = f.read() + # contains the start offset and length of each segment + # as byte positions in the data section + self._index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index) + else: + self._index_base = None # initialize memmapped data section - self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) + self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) + + @property + def index_base(self) -> List[Tuple[int, int]]: + if self._index_base is None: + raise ValueError("Index was not loaded. 
Set `load_index=True` during initialization.") + return self._index_base + + @property + def data(self) -> np.ndarray: + return self._data def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 85704e90..f3139895 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -206,7 +206,7 @@ class PackedMemMapDatasetBase(Dataset): } type_converter_for_torch = {1: np.uint8, 2: np.int32, 4: np.int64} - def __init__(self, raw_data_path: Path, sample_key: str): + def __init__(self, raw_data_path: Path, sample_key: str, load_index: Optional[bool] = True): """ Initializes the PackedMemMapDatasetBase object. @@ -214,6 +214,7 @@ def __init__(self, raw_data_path: Path, sample_key: str): raw_data_path (Path): Path to a packed binary file (*.pbin). Use `modalities data pack_encoded_data` to create one based on a JSONL-file. sample_key (str): The key to access the sample in the BatchEncoding. + load_index (bool, optional): Flag indicating whether to load the index. Defaults to True. Raises: RuntimeError: If the token representation with the given size is not supported. @@ -226,16 +227,16 @@ def __init__(self, raw_data_path: Path, sample_key: str): this needs to get replaced with a list of sample keys! """ super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - self._embedded_stream_data = EmbeddedStreamData(raw_data_path) + self._embedded_stream_data = EmbeddedStreamData(raw_data_path, load_index=load_index) self._token_size_in_bytes = self._embedded_stream_data.token_size_in_bytes try: self._token_dtype_on_disk = self.np_dtype_of_tokens_on_disk_from_bytes[self._token_size_in_bytes] self._token_dtype_in_ram = self.type_converter_for_torch[self._token_size_in_bytes] - except KeyError: + except KeyError as e: raise RuntimeError( f"Encountered a required token representation with {self._token_size_in_bytes}," " which is not supported. Consider using a smaller vocabulary." - ) + ) from e self._index = self._generate_packing_index() def _generate_packing_index(self) -> List[Tuple[int, int]]: @@ -292,7 +293,7 @@ def __getitem__(self, idx: int) -> BatchEncoding: class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase): """PackedMemMapDatasetContinuous class.""" - def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): + def __init__(self, raw_data_path: Path, sample_key: str, block_size: int, load_index: Optional[bool] = False): """ Initializes the PackedMemMapDatasetContinuous object. @@ -301,12 +302,38 @@ def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): Use `modalities data pack_encoded_data` to create one based on a JSONL-file. sample_key (str): The key to access the sample in the BatchEncoding. block_size (int): The size of the block. + load_index (bool, optional): Flag indicating whether to load the index. + This is only needed for debugging purposes to index the original documents. + The continuous packing does not need to load the index and should be + deactivated as it significantly increases the instantiation time. Defaults to False. Returns: None """ self.block_size = block_size - super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) + # TODO passing the load_index flag does not really comply with the inversion + # of control principle. We should refactor this in the future. 
+ super().__init__(raw_data_path=raw_data_path, sample_key=sample_key, load_index=load_index) + + @staticmethod + def _create_packed_index(total_tokens: int, block_size: int, token_size_in_bytes: int) -> List[Tuple[int, int]]: + # Given a fixed number of samples we can compute the total number of tokens as + # num_tokens = block_size + (block_size-1) * (num_samples-1) + # as the first sample always needs block_size many tokens and the following samples + # each need block_size-1 many tokens (since we can reuse the last target token as the first input token + # of the subsequent sample). + num_samples = (total_tokens - block_size) // (block_size - 1) + 1 + # create an index array of the form [0, 1, 2, ..., num_samples-1] + i_array = np.arange(num_samples) + # Vectorized operations + # create the starting byte position of each sample + first_component = (i_array * block_size - i_array) * token_size_in_bytes + # create the second component, which is the length of each sample in bytes + second_component = np.full(num_samples, block_size * token_size_in_bytes) + + # Combine both components into a 2D array of tuples (or list of tuples if needed) + result = np.stack((first_component, second_component), axis=1) + return result def _generate_packing_index(self) -> List[Tuple[int, int]]: # Generates the packing index for the dataset. @@ -321,27 +348,8 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: ) if self.block_size < 2: raise ValueError("Block size must be at least 2.") - # Given a fixed number of samples we can compute the total number of tokens as - # num_tokens = block_size + (block_size-1) * (num_samples-1) - # as the first sample always needs block_size many tokens and the following samples - # each need block_size-1 many tokens (since we can reuse the last target token as the first input token - # of the subsequent sample). - num_samples = (total_tokens - self.block_size) // (self.block_size - 1) + 1 - # given num_samples we calculate the starting index and length of each sample as tuple. 
- # return [ - # ((i * self.block_size - i) * self._token_size_in_bytes, self.block_size * self._token_size_in_bytes) - # for i in range(num_samples) - # ] - - # Create an array of indices (i values) - i_array = np.arange(num_samples) - - # Vectorized operations - first_component = (i_array * self.block_size - i_array) * self._token_size_in_bytes - second_component = np.full(num_samples, self.block_size * self._token_size_in_bytes) - # Combine both components into a 2D array of tuples (or list of tuples if needed) - result = np.stack((first_component, second_component), axis=1) + result = self._create_packed_index(total_tokens, self.block_size, self._token_size_in_bytes) return result From 2823745c2cff2f52511848724fa37474cc3a3f6b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 16:18:06 +0200 Subject: [PATCH 14/33] feat: added test for dataset packing --- tests/dataloader/test_packed_dataset.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/dataloader/test_packed_dataset.py b/tests/dataloader/test_packed_dataset.py index 9c988202..7b7812fd 100644 --- a/tests/dataloader/test_packed_dataset.py +++ b/tests/dataloader/test_packed_dataset.py @@ -1,6 +1,7 @@ import json from pathlib import Path +import numpy as np import pytest from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data @@ -111,7 +112,7 @@ def test_create_packed_dataset(indexed_dummy_data_path_long, wrapped_gpt2_tokeni assert not default_packed_dataset_path.is_file() packed_generator.run() packed_dataset = PackedMemMapDatasetContinuous( - default_packed_dataset_path, block_size=block_size, sample_key="input_ids" + default_packed_dataset_path, block_size=block_size, sample_key="input_ids", load_index=True ) # read in the raw jsonl files for manual tokenization @@ -241,3 +242,20 @@ def test_original_samples_in_packed_dataset(indexed_dummy_data_path_long, wrappe for sample, original_sample in zip(packed_dataset, jsonl_tokenized): assert sample["input_ids"].tolist() == original_sample + + +@pytest.mark.parametrize( + "token_size_in_bytes, block_size, total_tokens", [(1, 32, 32), (2, 32, 512), (4, 32, 1000), (4, 32, 1234)] +) +def test_continuously_packed_index(token_size_in_bytes: int, block_size: int, total_tokens: int): + num_samples = (total_tokens - block_size) // (block_size - 1) + 1 + # given num_samples we calculate the starting index and length of each sample as tuple. 
+ result_slow = [ + ((i * block_size - i) * token_size_in_bytes, block_size * token_size_in_bytes) for i in range(num_samples) + ] + + result_vectorized = PackedMemMapDatasetContinuous._create_packed_index( + total_tokens=total_tokens, block_size=block_size, token_size_in_bytes=token_size_in_bytes + ) + + assert np.all(result_slow == result_vectorized) From e1091bf3123594fcff39fcca829e7b35c1bcd62f Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 16:30:04 +0200 Subject: [PATCH 15/33] chore: removed legacy code from DataloaderFactory --- .../dataloader/dataloader_factory.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 1f12ca47..8d0424ef 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -26,35 +26,9 @@ def get_dataloader( collate_fn (Callable): Callable for shaping the batch num_workers (int): Number of workers for the dataloader pin_memory (bool): Flag indicating whether to pin memory - skip_num_batches (int, optional): Defines the number of batches to skip. - NOTE: The checkpoints are indexed with training steps (i.e., number of optimizer steps). - skip_num_batches must not be confused with the number of optimizer steps! - skip_num_batches = num optimizer steps * gradient accumulation steps - Defaults to 0. - fixed_num_batches: (int, optional): Fixed length of the dataloader by cutting off subsequent batches. - Note that these are NOT the global number of batches, but the amount of batches that an - individual rank sees. Make sure that the dataloader has at least fixed_num_batches. - Defaults to None. - Returns: LLMDataLoader: Instance of LLMDataLoader """ - - # batch_sampler = ResumableBatchSampler( - # start_index=skip_num_batches, underlying_batch_sampler=batch_sampler, max_num_elements=fixed_num_batches - # ) - - # if fixed_num_batches is not None and fixed_num_batches <= skip_num_batches: - # raise ConfigError("fixed_num_batches must be larger than skip_num_batches") - - # # make sure that the batch sampler has enough elements such - # # that we can fix the number of batches to num_batches - # if fixed_num_batches is not None and len(batch_sampler) < fixed_num_batches - skip_num_batches: - # raise ConfigError( - # f"The dataloader contains only {len(batch_sampler)} batches, which is less than " - # f"specified fixed amount of batches of {fixed_num_batches}." 
- # ) - dataloader = LLMDataLoader( dataloader_tag=dataloader_tag, batch_sampler=batch_sampler, From 95a121ed849da82894c1267e6596dba7c9770228 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 22:50:59 +0200 Subject: [PATCH 16/33] refactor: upated configs --- config_files/training/config_example_coca.yaml | 14 +++++++------- config_files/training/config_lorem_ipsum.yaml | 4 ++-- tutorials/library_usage/config_lorem_ipsum.yaml | 13 +++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index 570f9e5a..be9060ee 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -46,7 +46,7 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 coca_example_settings: train_num_samples: 64 @@ -96,7 +96,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -108,16 +107,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index ffa43f2d..c9ad4288 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -47,7 +47,7 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 collate_fn: @@ -93,7 +93,7 @@ train_dataloader: shuffle: true seed: 42 drop_last: true - skip_num_global_samples: 0 + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE diff --git a/tutorials/library_usage/config_lorem_ipsum.yaml b/tutorials/library_usage/config_lorem_ipsum.yaml index bd7cd59c..915e0ebd 100644 --- a/tutorials/library_usage/config_lorem_ipsum.yaml +++ b/tutorials/library_usage/config_lorem_ipsum.yaml @@ -48,7 +48,7 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 tokenizer: @@ -101,16 +101,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE From 611c77b16bef88096e39d33295bc3b5690a49958 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 
22:51:35 +0200
Subject: [PATCH 17/33] feat: added number conversion routine

---
 src/modalities/dataloader/samplers.py         | 10 ++++++++--
 src/modalities/registry/components.py         |  7 +++++++
 src/modalities/utils/number_conversion.py     | 19 +++++++++++++++++++
 tutorials/getting_started/example_config.yaml | 13 +++++++------
 4 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py
index cb963a4c..2fc4c9ce 100644
--- a/src/modalities/dataloader/samplers.py
+++ b/src/modalities/dataloader/samplers.py
@@ -19,6 +19,10 @@ def __init__(
             underlying_batch_sampler (BatchSampler): Sampler providing the batch ids.
             max_num_elements (Optional[int]): The maximum number of elements the sampler returns. Default None.
 
+            Warning: During instantiation the indices are computed and stored in memory. This is needed for skipping
+            and is very costly with large datasets, leading to long delays until the training starts.
+            In this case, it is recommended to use the `ResumableDistributedSampler` instead.
+
         Returns:
             None
         """
@@ -26,8 +30,10 @@ def __init__(
         self.start_index = start_index
         self.max_num_elements = max_num_elements
         self.underlying_batch_sampler = underlying_batch_sampler
-        # NOTE: we are only iterating ove the indices not the actual data
-        # so this is relatively cheap
+        # NOTE: we are only iterating over the indices not the actual data
+        # so this is relatively cheap for small datasets.
+        # For large-scale datasets in the range of billions to trillions of samples, this can be very costly
+        # and delay the training start.
         self.indices = list(iter(self.underlying_batch_sampler))
         # We discard the samples that come after max_num_elements
         # NOTE, that skipping is implemented in __iter__ and __len__.
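The warning added above comes down to how skipping is implemented: ResumableBatchSampler materializes every batch index before it can drop the first ones, whereas a distributed sampler can skip by index arithmetic alone. A minimal sketch of the two strategies (illustrative only; both helper functions below are invented for this comparison and are not part of the patch):

# Illustrative comparison only; neither function exists in modalities.
def eager_skip(num_samples: int, batch_size: int, skip_batches: int) -> list:
    # ResumableBatchSampler-style: build *all* batch indices up front, then slice.
    # Memory use and start-up time grow with the dataset size.
    batches = [list(range(i, i + batch_size)) for i in range(0, num_samples, batch_size)]
    return batches[skip_batches:]


def arithmetic_skip(num_samples: int, batch_size: int, skip_batches: int) -> range:
    # ResumableDistributedSampler-style: the first remaining sample index is
    # computed directly; nothing is materialized per sample.
    return range(skip_batches * batch_size, num_samples)


assert [batch[0] for batch in eager_skip(10, 2, 2)] == [4, 6, 8]
assert list(arithmetic_skip(10, 2, 2))[:2] == [4, 5]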
diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 89079f04..6059b9a5 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -90,6 +90,7 @@ LocalNumBatchesFromNumTokensConfig, NumberConversion, NumberConversionFromCheckpointPathConfig, + NumSamplesFromNumTokensConfig, NumStepsFromNumSamplesConfig, NumStepsFromNumTokensConfig, NumStepsFromRawDatasetIndexConfig, @@ -262,6 +263,12 @@ class ComponentEntity: NumberConversion.get_local_num_batches_from_num_tokens, LocalNumBatchesFromNumTokensConfig, ), + ComponentEntity( + "number_conversion", + "num_samples_from_num_tokens", + NumberConversion.get_num_samples_from_num_tokens, + NumSamplesFromNumTokensConfig, + ), ComponentEntity( "number_conversion", "num_steps_from_num_samples", diff --git a/src/modalities/utils/number_conversion.py b/src/modalities/utils/number_conversion.py index c8348af0..3dc56732 100644 --- a/src/modalities/utils/number_conversion.py +++ b/src/modalities/utils/number_conversion.py @@ -20,6 +20,11 @@ class LocalNumBatchesFromNumTokensConfig(BaseModel): local_micro_batch_size: Annotated[int, Field(strict=True, gt=0)] +class NumSamplesFromNumTokensConfig(BaseModel): + num_tokens: Annotated[int, Field(strict=True, ge=0)] + sequence_length: Annotated[int, Field(strict=True, gt=0)] + + class NumStepsFromNumSamplesConfig(BaseModel): num_ranks: Annotated[int, Field(strict=True, gt=0)] local_micro_batch_size: Annotated[int, Field(strict=True, gt=0)] @@ -98,6 +103,20 @@ def get_local_num_batches_from_num_samples( """ return (global_num_samples) // num_ranks // local_micro_batch_size + @staticmethod + def get_num_samples_from_num_tokens(num_tokens: int, sequence_length: int) -> int: + """Calculates the number of samples given the global number of tokens and sequence length. + + Args: + num_tokens (int): Global number of tokens. + sequence_length (int): Sequence length of the model. + + Returns: + int: Number of samples. 
+ """ + num_samples = num_tokens // sequence_length + return num_samples + @staticmethod def get_local_num_batches_from_num_tokens( num_ranks: int, global_num_tokens: int, sequence_length: int, local_micro_batch_size: int diff --git a/tutorials/getting_started/example_config.yaml b/tutorials/getting_started/example_config.yaml index ee68d9ec..f1737f94 100644 --- a/tutorials/getting_started/example_config.yaml +++ b/tutorials/getting_started/example_config.yaml @@ -48,7 +48,7 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 collate_fn: @@ -73,7 +73,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -85,15 +84,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true + seed: 42 drop_last: true - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE From 4a759f3659bca6211d1c92b8b1245573cc96497a Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 23:13:24 +0200 Subject: [PATCH 18/33] chore: updated tutorial configs --- .../configs/pretraining_config.yaml | 14 +++++++------- .../warmstart/configs/pre_training_config.yaml | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml index d4c7ec2f..166d25fb 100644 --- a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml +++ b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml @@ -47,7 +47,7 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 collate_fn: @@ -72,7 +72,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -84,16 +83,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE diff --git a/tutorials/warmstart/configs/pre_training_config.yaml b/tutorials/warmstart/configs/pre_training_config.yaml index 9bf1de87..30db4adf 100644 --- a/tutorials/warmstart/configs/pre_training_config.yaml +++ b/tutorials/warmstart/configs/pre_training_config.yaml @@ -47,7 +47,7 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 collate_fn: @@ -72,7 +72,6 @@ train_dataloader: 
num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -84,16 +83,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE From b5cc6177ac4bc8eb2f061fa15f3f7680c73c2b10 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 23:20:36 +0200 Subject: [PATCH 19/33] refactor: removed obsolete test test_dataloader_with_fixed_num_batches --- tests/dataloader/test_dataloader.py | 59 +------------------ .../dataloader_with_fixed_num_batches.yaml | 56 ------------------ 2 files changed, 1 insertion(+), 114 deletions(-) delete mode 100644 tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 3b0ef7be..f2241297 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,9 +1,8 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict import numpy as np -import pytest import torch from pydantic import BaseModel from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler @@ -14,7 +13,6 @@ from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader from modalities.dataloader.dataset import Dataset from modalities.dataloader.samplers import ResumableBatchSampler -from modalities.models.gpt2.collator import CollateFnIF from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry @@ -236,58 +234,3 @@ class DataloaderTestModel(BaseModel): for batch_1, batch_2 in zip(batches_rank_0, batches_rank_1): assert ~(batch_1.samples["input_ids"] == batch_2.samples["input_ids"]).all() assert ~(batch_1.targets["target_ids"] == batch_2.targets["target_ids"]).all() - - -@pytest.mark.parametrize( - "global_rank", - [0, 1], -) -def test_dataloader_with_fixed_num_batches(global_rank): - class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType - fixed_num_batches: int - - class IdentityCollateFn(CollateFnIF): - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: - return batch - - root_dir = Path(__file__).parents[0] - - config_path = root_dir / "yaml_configs/dataloader_with_fixed_num_batches.yaml" - # we inject a prebuilt dataset and collate_fn, as well as, the global rank constant from outside - dataset = SequenceDataset(list(range(1000))) - config_dict = load_app_config_dict(config_path) - config_dict["settings"]["cuda_env"]["global_rank"] = global_rank - config_dict["train_dataloader"]["config"]["batch_sampler"]["config"]["sampler"]["config"]["rank"] = global_rank - config_dict["train_dataset"] = dataset - config_dict["collate_fn"] = IdentityCollateFn() - - # build the remaining components - registry = Registry(COMPONENTS) - component_factory = ComponentFactory(registry=registry) - components: 
DataloaderTestModel = component_factory.build_components( - config_dict=config_dict, components_model_type=DataloaderTestModel - ) - dataloader = components.train_dataloader - - # calculate the fixed_num_batches and - # compare it with the one calculated during the component build and the dataloader length - cfg = config_dict["settings"]["training"] - world_size = config_dict["settings"]["cuda_env"]["world_size"] - calculated_fixed_num_batches = ( - cfg["global_num_train_tokens"] // cfg["local_train_micro_batch_size"] // cfg["sequence_length"] // world_size - ) - assert calculated_fixed_num_batches == components.fixed_num_batches - assert len(dataloader) == calculated_fixed_num_batches - - # We make sure that the dataloader outputs the correct batches as follows: - # The dataset contains 1000 samples (NOTE that we neglected squence_length and made each sample an integer value) - # we calculated 8 batches above per rank and have 2 ranks in total. - # Therefore the dataloader for rank 0 returns 8 ordered batches of batch_size 2. - # The batches are ordered and not shuffled as per YAML configuration. - # We expect the following output: - # [[0, 2], [4, 6], [8, 10], ..., [24, 26], [28, 30]] (global_rank=0) - # [[1, 3], [5, 7], [9, 11], ..., [25, 27], [29, 31]] (global_rank=1) - calculated_dataloader_content = np.array(list(range(global_rank, 32 + global_rank, 2))).reshape(-1, 2).tolist() - actual_dataloader_content = [i for i in dataloader] - assert calculated_dataloader_content == actual_dataloader_content diff --git a/tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml b/tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml deleted file mode 100644 index cb2385f5..00000000 --- a/tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml +++ /dev/null @@ -1,56 +0,0 @@ -# NOTE, settings is not type checked in the instantiation model (specified within the test), as the settings are not used in the pydantic model. -# Therefore, we can place arbitrary values in the settings field. -# Only train_dataloader and fixed_num_batches are type checked in the instantiation model. 
- -settings: - training: - local_train_micro_batch_size: 2 - global_num_train_tokens: 128 - sequence_length: 4 - cuda_env: - global_rank: 0 - world_size: 2 - -fixed_num_batches: - component_key: number_conversion - variant_key: local_num_batches_from_num_tokens - config: - num_ranks: ${settings.cuda_env.world_size} - global_num_tokens: ${settings.training.global_num_train_tokens} - sequence_length: ${settings.training.sequence_length} - local_micro_batch_size: ${settings.training.local_train_micro_batch_size} - -train_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - dataloader_tag: train - skip_num_batches: 0 - fixed_num_batches: - instance_key: fixed_num_batches - pass_type: BY_REFERENCE - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - drop_last: true - shuffle: false - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE From 053275e1357c5014f56b2b72bc3b8f57f19a4c3b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 25 Sep 2024 23:38:28 +0200 Subject: [PATCH 20/33] refactor: adapted more failing test to the dataloader changes --- config_files/training/config_lorem_ipsum.yaml | 2 +- src/modalities/config/instantiation_models.py | 2 +- .../test_yaml_configs/config_lorem_ipsum.yaml | 20 +++++++++---------- tests/utils/test_number_conversion.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index c9ad4288..670610da 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -26,7 +26,7 @@ settings: local_train_micro_batch_size: 1 sequence_length: 256 training_target: - num_target_tokens: + num_target_tokens: component_key: number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 690559b3..533aa3f9 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -56,7 +56,7 @@ class TrainingTarget(BaseModel): class TrainingProgress(BaseModel): global_num_seen_tokens: Annotated[int, Field(strict=True, ge=0)] num_seen_steps: Annotated[int, Field(strict=True, ge=0)] - local_num_seen_batches: Annotated[int, Field(strict=True, ge=0)] + num_seen_samples: Annotated[int, Field(strict=True, ge=0)] last_step: Annotated[int, Field(strict=True, ge=-1)] diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml index 96f2f7db..e9552785 100644 --- a/tests/test_yaml_configs/config_lorem_ipsum.yaml +++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml @@ -46,10 +46,10 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 -collate_fn: +collate_fn: component_key: collate_fn variant_key: gpt_2_llm_collator config: @@ -71,7 +71,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: 
${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -83,16 +82,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE @@ -103,7 +103,7 @@ val_dataloader: config: num_workers: 2 pin_memory: true - dataloader_tag: "val" + dataloader_tag: val dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -134,7 +134,7 @@ test_dataloader: config: num_workers: 2 pin_memory: true - dataloader_tag: "test" + dataloader_tag: test dataset: instance_key: train_dataset pass_type: BY_REFERENCE diff --git a/tests/utils/test_number_conversion.py b/tests/utils/test_number_conversion.py index 9cc1470d..f54d807a 100644 --- a/tests/utils/test_number_conversion.py +++ b/tests/utils/test_number_conversion.py @@ -360,8 +360,8 @@ def test_num_steps_from_raw_dataset_index( num_ranks: int, local_micro_batch_size: int, gradient_accumulation_steps: int ): working_dir = Path(__file__).parent - raw_dataset_path = working_dir / "../data/datasets/lorem_ipsum_long.jsonl" - raw_index_path = working_dir / "../data/datasets/lorem_ipsum_long.idx" + raw_dataset_path = working_dir / "../../data/lorem_ipsum_long.jsonl" + raw_index_path = working_dir / "../../data/lorem_ipsum_long.idx" with open(raw_dataset_path, "r") as f: num_samples = len(f.readlines()) From e5ee5e90a5f1f2d209dc50800281435b50cfaa9c Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 26 Sep 2024 21:49:30 +0200 Subject: [PATCH 21/33] refactor: removed RepeatingDataLoader --- src/modalities/config/config.py | 11 -- src/modalities/dataloader/dataloader.py | 161 +----------------- .../dataloader/dataloader_factory.py | 23 +-- src/modalities/registry/components.py | 4 - tests/dataloader/samplers/test_samplers.py | 25 --- tests/dataloader/test_dataloader.py | 151 +++------------- 6 files changed, 28 insertions(+), 347 deletions(-) delete mode 100644 tests/dataloader/samplers/test_samplers.py diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 8672f391..a1a81923 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -311,11 +311,6 @@ class BatchSamplerConfig(BaseModel): drop_last: Literal[True] = True -class ResumableBatchSamplerConfig(BaseModel): - sampler: PydanticSamplerIFType - start_index: Annotated[int, Field(strict=True, gt=0)] - - class GPT2LLMCollateFnConfig(BaseModel): sample_key: str target_key: str @@ -330,12 +325,6 @@ class LLMDataLoaderConfig(BaseModel): pin_memory: bool -class RepeatingDataLoaderConfig(BaseModel): - dataloader: PydanticLLMDataLoaderIFType - reshuffle_after_epoch: Optional[bool] = False - num_epochs: Annotated[int, Field(strict=True, ge=1)] - - class DummyProgressSubscriberConfig(BaseModel): pass diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index fbf5cc36..27d9bec3 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,10 +1,8 @@ from typing import Iterable, Optional, Union -from torch.utils.data 
import Dataset, DistributedSampler, Sampler +from torch.utils.data import BatchSampler, Dataset, Sampler from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t -from modalities.dataloader.samplers import ResumableBatchSampler - class LLMDataLoader(DataLoader[T_co]): """LLMDataLoader is a custom DataLoader class that extends the PyTorch DataLoader class.""" @@ -12,7 +10,7 @@ class LLMDataLoader(DataLoader[T_co]): def __init__( self, dataloader_tag: str, - batch_sampler: ResumableBatchSampler, + batch_sampler: BatchSampler, dataset: Dataset[T_co], batch_size: Optional[int] = 1, sampler: Union[Sampler, Iterable, None] = None, @@ -34,7 +32,7 @@ def __init__( Args: dataloader_tag (str): The tag for the dataloader. - batch_sampler (ResumableBatchSampler): The batch sampler used for sampling batches. + batch_sampler (BatchSampler): The batch sampler used for sampling batches. dataset (Dataset[T_co]): The dataset to load the data from. batch_size (Optional[int], optional): The number of samples per batch. Defaults to 1. sampler (Union[Sampler, Iterable, None], optional): The sampler used for sampling data. Defaults to None. @@ -77,7 +75,6 @@ def __init__( ) self._dataloader_tag = dataloader_tag - self._batch_size = batch_sampler.batch_size @property def dataloader_tag(self) -> str: @@ -88,155 +85,3 @@ def dataloader_tag(self) -> str: str: The dataloader tag. """ return self._dataloader_tag - - @property - def batch_size(self) -> int: - """ - Returns the batch size used in the dataloader. - The batch size is the number of samples in each batch of data. - - Returns: - int: The batch size used in the dataloader. - - Note: - The parent Dataloader class has already a batch_size property defined which is originally used - when the batch_sampler is not specified. Since the LLMDataLoader enforces to always use a BatchSampler, - we defined/ override the property batch_size to return the actual batch size used in the dataloder. - BatchSampler is required, as we must seek forward in the dataloder during a warm start and - we don't want to load all the data during the fast-forward. - """ - return self._batch_size - - @batch_size.setter - def batch_size(self, value: int): - """ - Set the batch size for the dataloader. - - Parameters: - value (int): The batch size to be set. - - Returns: - None - """ - self._batch_size = value - - @property - def fast_forward_batch_id(self) -> int: - """ - The batch ID until which we fast-forward, as specified in the ResumableBatchSampler. - - Returns: - int: fast forward batch ID - """ - return self.batch_sampler.start_index - - -class RepeatingDataLoader(LLMDataLoader[T_co]): - """ - RepeatingDataLoader is a custom DataLoader class that repeats the given dataloader - for the specified number of epochs.""" - - def __init__(self, dataloader: LLMDataLoader[T_co], num_epochs: int, reshuffle_after_epoch: bool = False): - """ - Initializes a RepeatingDataLoader object that repeats the given dataloader for the specified number of epochs. - This is especially useful for DataLoader types that we wish to automatically restart upon completion. - - Args: - dataloader (LLMDataLoader[T_co]): The dataloader to be wrapped. - num_epochs (int): The number of epochs to iterate over the dataloader. - reshuffle_after_epoch (bool, optional): Flag indicating whether to reshuffle the dataloader - after each epoch. Defaults to False. 
- - Returns: - None - - Note: - Based on: https://github.com/microsoft/DeepSpeed/blob/99951caa3d2155a3bb84109a0828543793e088cc/deepspeed/runtime/dataloader.py#L17 - """ - self.dataloader = dataloader - self.data_iter = iter(self.dataloader) - self.current_epoch = 0 - self.reshuffle_after_epoch = reshuffle_after_epoch - self.num_epochs = num_epochs - - def __iter__(self): - """ - Returns an iterator object for the DataLoader. - """ - return self - - def __next__(self): - """ - Returns the next batch of data from the DataLoader. - - Raises: - StopIteration: If there are no more batches of data to return. - - Returns: - batch: The next batch of data. - """ - try: - batch = next(self.data_iter) - except StopIteration as e: - if self.dataloader.sampler is not None: - self.current_epoch += 1 - # After finishing an epoch, we set the start_index to 0 to start from the beginning - # The start_index might have been >0 in case of a warm start - self.dataloader.batch_sampler.start_index = 0 - - if self.reshuffle_after_epoch: - # In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating - # the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. - # Otherwise, the same ordering will be always used. See discussion: - # https://discuss.pytorch.org/t/why-is-sampler-set-epoch-epoch-needed-for-distributedsampler/149672 - if isinstance(self.dataloader.sampler, DistributedSampler): - self.dataloader.sampler.set_epoch(self.current_epoch) - else: - raise NotImplementedError( - "Reshuffling after each epoch is only supported for DistributedSampler" - ) - if self.current_epoch < self.num_epochs: - self.data_iter = iter(self.dataloader) - batch = next(self.data_iter) - else: - raise StopIteration(f"RepeatingDataLoader has completed after {self.current_epoch} epochs") from e - return batch - - @property - def dataloader_tag(self) -> str: - """ - Returns the dataloader tag. - - Returns: - str: The dataloader tag. - """ - return self.dataloader.dataloader_tag - - @property - def batch_size(self) -> int: - """ - Returns the batch size used by the dataloader. - - Returns: - int: The batch size used by the dataloader. - """ - return self.dataloader.batch_size - - @property - def fast_forward_batch_id(self) -> int: - """ - The batch ID until which we fast-forward, as specified in the ResumableBatchSampler. - - Returns: - int: fast forward batch id - """ - return self.dataloader.fast_forward_batch_id - - def __len__(self) -> int: - """ - Returns the total number of steps in the dataloader. - - Returns: - int: The total number of steps. - """ - return self.num_epochs * len(self.dataloader) diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 8d0424ef..56d9db1b 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -3,7 +3,7 @@ from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader +from modalities.dataloader.dataloader import LLMDataLoader class DataloaderFactory: @@ -39,24 +39,3 @@ def get_dataloader( ) return dataloader - - @staticmethod - def get_repeating_dataloader( - dataloader: LLMDataLoader, num_epochs: int, reshuffle_after_epoch: bool = False - ) -> RepeatingDataLoader: - """ - Returns a RepeatingDataLoader object that repeats the given dataloader - for the specified number of epochs. 
- - Parameters: - dataloader (LLMDataLoader): The dataloader to be repeated. - num_epochs (int): The number of times the dataloader should be repeated. - reshuffle_after_epoch (bool, optional): Flag indicating whether to reshuffle - the data after each epoch. Defaults to False. - - Returns: - RepeatingDataLoader: A RepeatingDataLoader object that repeats the given dataloader - for the specified number of epochs. - """ - dataloader = RepeatingDataLoader(dataloader, num_epochs, reshuffle_after_epoch) - return dataloader diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 6059b9a5..ecda8699 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -41,7 +41,6 @@ PackedMemMapDatasetMegatronConfig, PreTrainedHFTokenizerConfig, PreTrainedSPTokenizerConfig, - RepeatingDataLoaderConfig, ResumableDistributedSamplerConfig, RichProgressSubscriberConfig, RichResultSubscriberConfig, @@ -192,9 +191,6 @@ class ComponentEntity: ComponentEntity("collate_fn", "coca_collator", CoCaCollatorFn, CoCaCollateFnConfig), # data loaders ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), - ComponentEntity( - "data_loader", "repeating_data_loader", DataloaderFactory.get_repeating_dataloader, RepeatingDataLoaderConfig - ), # checkpointing ComponentEntity("checkpoint_saving", "default", CheckpointSaving, CheckpointSavingConfig), # checkpointing strategies diff --git a/tests/dataloader/samplers/test_samplers.py b/tests/dataloader/samplers/test_samplers.py deleted file mode 100644 index 6d3a73ed..00000000 --- a/tests/dataloader/samplers/test_samplers.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from torch.utils.data.sampler import BatchSampler - -from modalities.dataloader.samplers import ResumableBatchSampler - - -def test_resumable_sampler(resumable_batch_sampler: ResumableBatchSampler): - existing_sampler: BatchSampler = resumable_batch_sampler.underlying_batch_sampler - indices_1 = [i for i in resumable_batch_sampler] - indices_2 = [i for i in existing_sampler][resumable_batch_sampler.start_index :] - - data_source = existing_sampler.sampler.data_source[resumable_batch_sampler.start_index :] - assert indices_1 == indices_2 - assert indices_1 != data_source - - -def test_resumable_batch_sampler(resumable_batch_sampler: ResumableBatchSampler): - underlying_batch_sampler: BatchSampler = resumable_batch_sampler.underlying_batch_sampler - values_1 = [i for i in resumable_batch_sampler] - - values_2_flat = underlying_batch_sampler.sampler.data_source[::-1][ - underlying_batch_sampler.batch_size * resumable_batch_sampler.start_index : - ] - values_2 = torch.IntTensor(values_2_flat).reshape([-1, underlying_batch_sampler.batch_size]).tolist() - assert values_1 == values_2 diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index f2241297..8e264100 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,18 +1,17 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict +from typing import Any import numpy as np import torch from pydantic import BaseModel -from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler +from torch.utils.data import BatchSampler, SequentialSampler from modalities.config.component_factory import ComponentFactory from modalities.config.config import load_app_config_dict from modalities.config.pydanctic_if_types import 
PydanticLLMDataLoaderIFType -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader +from modalities.dataloader.dataloader import LLMDataLoader from modalities.dataloader.dataset import Dataset -from modalities.dataloader.samplers import ResumableBatchSampler from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry @@ -31,142 +30,33 @@ def __getitem__(self, idx: int) -> Any: def test_resumable_dataloader(): batch_size = 3 - start_index = 2 dataset = list(range(12))[::-1] seq_sampler = SequentialSampler(data_source=dataset) batch_sampler = BatchSampler(sampler=seq_sampler, batch_size=batch_size, drop_last=False) - resumable_batch_sampler = ResumableBatchSampler(underlying_batch_sampler=batch_sampler, start_index=start_index) - dataloader = LLMDataLoader(dataloader_tag="train", dataset=dataset, batch_sampler=resumable_batch_sampler) + dataloader = LLMDataLoader(dataloader_tag="train", dataset=dataset, batch_sampler=batch_sampler) flat_samples = torch.cat([i for i in dataloader]) - original_samples = torch.IntTensor(dataset[start_index * batch_size :]) + original_samples = torch.IntTensor(dataset) assert (flat_samples == original_samples).all() -def test_dataloader_from_config(dummy_config: Dict): - start_index = 2 - dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index - - class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType - - registry = Registry(COMPONENTS) - component_factory = ComponentFactory(registry=registry) - components: DataloaderTestModel = component_factory.build_components( - config_dict=dummy_config, components_model_type=DataloaderTestModel - ) - - dataloader_1: LLMDataLoader = components.train_dataloader - dataset = dataloader_1.dataset - resumable_batch_sampler: ResumableBatchSampler = dataloader_1.batch_sampler - distributed_sampler = resumable_batch_sampler.underlying_batch_sampler.sampler - batch_sampler = BatchSampler(sampler=distributed_sampler, batch_size=dataloader_1.batch_size, drop_last=True) - dataloader_2 = LLMDataLoader( - dataloader_tag="train", dataset=dataset, batch_sampler=batch_sampler, collate_fn=dataloader_1.collate_fn - ) - - samples_1 = [batch for _, batch in zip(range(10), dataloader_1)] - samples_2 = [batch for _, batch in zip(range(10), dataloader_2)] - - assert len(dataloader_2) == len(dataset) // dataloader_1.batch_size - - assert len(dataloader_1) + start_index == len(dataloader_2) - - for batch_1, batch_2 in zip(samples_1, samples_2): - assert ~(batch_1.samples["input_ids"] == batch_2.samples["input_ids"]).all() - assert ~(batch_1.targets["target_ids"] == batch_2.targets["target_ids"]).all() - - for batch_1, batch_2 in zip(samples_1, samples_2[start_index:]): - assert (batch_1.samples["input_ids"] == batch_2.samples["input_ids"]).all() - assert (batch_1.targets["target_ids"] == batch_2.targets["target_ids"]).all() - - def test_dataloader_batching(): batch_size = 2 - skip_num_batches = 2 dataset = list(range(10)) seq_sampler = SequentialSampler(data_source=dataset) batch_sampler = BatchSampler(sampler=seq_sampler, batch_size=batch_size, drop_last=False) - # the LLMDataLoader always requires a ResumableBatchSampler - resumable_batch_sampler = ResumableBatchSampler( - underlying_batch_sampler=batch_sampler, start_index=skip_num_batches - ) - dataloader = LLMDataLoader(dataloader_tag="train", dataset=dataset, batch_sampler=resumable_batch_sampler) + dataloader = LLMDataLoader(dataloader_tag="train", dataset=dataset, 
batch_sampler=batch_sampler) batches_1 = torch.stack([i for i in dataloader]) batches_2 = torch.stack([i for i in dataloader]) assert batches_1.equal(batches_2) - assert batches_1.flatten().tolist() == dataset[skip_num_batches * batch_size :] - - -def test_repeating_dataloader_without_shuffling(): - batch_size = 2 - skip_num_batches = 2 - num_samples = 10 - dataset = list(range(num_samples)) - seq_sampler = SequentialSampler(data_source=dataset) - # the LLMDataLoader always requires a ResumableBatchSampler - # create the dataloader that skips the first skip_num_batches - batch_sampler_skipped = BatchSampler(sampler=seq_sampler, batch_size=batch_size, drop_last=True) - resumable_batch_sampler_skipped = ResumableBatchSampler( - underlying_batch_sampler=batch_sampler_skipped, start_index=skip_num_batches - ) - dataloader_skipped = LLMDataLoader( - dataloader_tag="train", dataset=dataset, batch_sampler=resumable_batch_sampler_skipped - ) - - # create dataloader that skips no batches - batch_sampler = BatchSampler(sampler=seq_sampler, batch_size=batch_size, drop_last=True) - resumable_batch_sampler = ResumableBatchSampler(underlying_batch_sampler=batch_sampler, start_index=0) - dataloader = LLMDataLoader(dataloader_tag="train", dataset=dataset, batch_sampler=resumable_batch_sampler) - - # create repeating dataloader that first skips the skip_num_batches - # in epoch 0 and then returns the batches from the beginning - repeating_dataloader = RepeatingDataLoader(dataloader=dataloader_skipped, reshuffle_after_epoch=False, num_epochs=2) - - num_samples // batch_size - # get the batches for two epochs - batches_1 = torch.stack([i for i in dataloader_skipped] + [i for i in dataloader]) - batches_2 = torch.stack([i for i in repeating_dataloader]) - - assert batches_1.equal(batches_2) - assert batches_1.flatten().tolist() == dataset[skip_num_batches * batch_size :] + dataset - - -def test_repeating_dataloader_with_shuffling(): - batch_size = 2 - skip_num_batches = 2 - num_samples = 10 - dataset = list(range(num_samples)) - - generator = torch.Generator().manual_seed(42) - random_sampler = RandomSampler(data_source=dataset, generator=generator) - batch_sampler = BatchSampler(sampler=random_sampler, batch_size=batch_size, drop_last=False) - - # create dataloader that skips not batches - resumable_batch_sampler = ResumableBatchSampler( - underlying_batch_sampler=batch_sampler, start_index=skip_num_batches - ) - dataloader = LLMDataLoader(dataloader_tag="train", dataset=dataset, batch_sampler=resumable_batch_sampler) - - # create repeating dataloader that first skips the skip_num_batches - # in epoch 0 and then returns the batches from the beginning - repeating_dataloader = RepeatingDataLoader(dataloader=dataloader, reshuffle_after_epoch=False, num_epochs=2) - - # get the batches for two epochs - num_batches_per_epoch = num_samples // batch_size - batches = torch.stack([i for i in repeating_dataloader]) - batches_epoch_1 = batches[: num_batches_per_epoch - skip_num_batches] - batches_epoch_2 = batches[num_batches_per_epoch - skip_num_batches :] - # when we skip 2 batches only 3 batches are left, i.e., 6 samples - assert len(set(batches_epoch_1.flatten().tolist())) == 6 - assert set(batches_epoch_2.flatten().tolist()) == set(range(10)) + assert batches_1.flatten().tolist() == dataset def test_skipped_and_distributed_dataloader_from_config(): class DataloaderTestModel(BaseModel): train_dataloader: PydanticLLMDataLoaderIFType - skip_num_batches: int + skip_num_samples: int root_dir = Path(__file__).parents[0] 
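The hunk that follows introduces the per-rank skip bookkeeping the test asserts on. The arithmetic can be reproduced in isolation; the token and batch numbers below are assumptions chosen for illustration (not read from the test's YAML config), picked so that they yield the values the test expects:

# Standalone sketch of the skip bookkeeping; the concrete numbers are assumed.
global_num_seen_tokens = 2048  # assumed warmstart progress
sequence_length = 256          # assumed model sequence length
world_size = 2
local_micro_batch_size = 2

# num_samples_from_num_tokens: one sample per `sequence_length` tokens
skip_num_samples = global_num_seen_tokens // sequence_length
# each rank skips its share of those samples, expressed in local batches
skip_num_local_batches = skip_num_samples // world_size // local_micro_batch_size

assert (skip_num_samples, skip_num_local_batches) == (8, 2)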
@@ -180,6 +70,13 @@ class DataloaderTestModel(BaseModel): config_dict=config_dict, components_model_type=DataloaderTestModel ) + world_size = config_dict["settings"]["cuda_env"]["world_size"] + local_micro_batch_size = config_dict["settings"]["training"]["local_train_micro_batch_size"] + skip_num_local_batches = components_rank_0.skip_num_samples // world_size // local_micro_batch_size + + assert world_size == 2 + assert skip_num_local_batches == 2 + config_dict["settings"]["cuda_env"]["global_rank"] = 1 config_dict["train_dataloader"]["config"]["batch_sampler"]["config"]["sampler"]["config"]["rank"] = 1 components_rank_1: DataloaderTestModel = component_factory.build_components( @@ -188,26 +85,26 @@ class DataloaderTestModel(BaseModel): dataset = components_rank_0.train_dataloader.dataset - batches_rank_0 = [batch for _, batch in zip(range(10), components_rank_0.train_dataloader)] - batches_rank_1 = [batch for _, batch in zip(range(10), components_rank_1.train_dataloader)] + batches_rank_0 = [batch for batch in components_rank_0.train_dataloader] + batches_rank_1 = [batch for batch in components_rank_1.train_dataloader] # make sure that the dataloaders for the two ranks have the correct number of batches assert ( len(components_rank_0.train_dataloader) - == len(dataset) // 2 // components_rank_0.train_dataloader.batch_size - components_rank_0.skip_num_batches + == (len(dataset) - components_rank_0.skip_num_samples) // world_size // local_micro_batch_size ) assert ( len(components_rank_1.train_dataloader) - == len(dataset) // 2 // components_rank_0.train_dataloader.batch_size - components_rank_0.skip_num_batches + == (len(dataset) - components_rank_1.skip_num_samples) // world_size // local_micro_batch_size ) # we manually build up the batches from each dataloader to compare on a value basis - # with [2:] we skip the first two batches - dataset_indices_rank_0 = np.arange(0, 28, 2).reshape(-1, 2)[2:] - dataset_indices_rank_1 = np.arange(1, 29, 2).reshape(-1, 2)[2:] + # with [skip_num_local_batches:] we skip the first two batches + dataset_indices_rank_0 = np.arange(0, 28, 2).reshape(-1, local_micro_batch_size)[skip_num_local_batches:] + dataset_indices_rank_1 = np.arange(1, 29, 2).reshape(-1, local_micro_batch_size)[skip_num_local_batches:] - assert all((dataset_indices_rank_0 == list(components_rank_0.train_dataloader.batch_sampler)).flatten()) - assert all((dataset_indices_rank_1 == list(components_rank_1.train_dataloader.batch_sampler)).flatten()) + assert np.all((dataset_indices_rank_0 == list(components_rank_0.train_dataloader.batch_sampler))) + assert np.all((dataset_indices_rank_1 == list(components_rank_1.train_dataloader.batch_sampler))) batches_recomputed_rank_0 = [] for batch_indices in dataset_indices_rank_0: From 26b11527d57040f1a66b95ce20f6f924280ddcf7 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 26 Sep 2024 21:50:26 +0200 Subject: [PATCH 22/33] refactor: removed ResumableBatchSampler --- src/modalities/dataloader/samplers.py | 79 +------------------ tests/conftest.py | 12 --- .../yaml_configs/skipped_dataloader.yaml | 22 +++--- 3 files changed, 10 insertions(+), 103 deletions(-) diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index 2fc4c9ce..a53b024d 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -3,71 +3,7 @@ import torch import torch.distributed as dist -from torch.utils.data import BatchSampler, Dataset, Sampler - - -class ResumableBatchSampler(Sampler): - def 
__init__( - self, start_index: int, underlying_batch_sampler: BatchSampler, max_num_elements: Optional[int] = None - ): - """ - Sampler which starts at a specified batch index and continues sampling for - for a given sampler. Works with normal samplers and BatchSamplers. - - Args: - start_index (int): index to start sampling from - underlying_batch_sampler (BatchSampler): Sampler providing the batch ids. - max_num_elements (Optional[int]): The maximum number of elements the sampler returns. Default None. - - Warning: During instantiation the indices are computed and stored in memory. This is needed for skipping - and is very costly with large datasets, leading to long delays until the training starts. - In this case, it is recommended to use the `ResumableDistributedSampler` instead. - - Returns: - None - """ - - self.start_index = start_index - self.max_num_elements = max_num_elements - self.underlying_batch_sampler = underlying_batch_sampler - # NOTE: we are only iterating over the indices not the actual data - # so this is relatively cheap for small datasets. - # For large-scale datasets in the range of billions to trillion samples, this can be very costly - # and delay the training start. - self.indices = list(iter(self.underlying_batch_sampler)) - # We discard the samples that come after max_num_elements - # NOTE, that skipping is implemented in __iter__ and __len__. - if self.max_num_elements is not None: - self.indices = self.indices[:max_num_elements] - - def __iter__(self): - """ - Returns an iterator over the indices starting from the start_index. - - Returns: - iterator: An iterator over the indices. - """ - return iter(self.indices[self.start_index :]) - - def __len__(self): - """ - Returns the length of the sampler, which is the number of indices minus the start index. - - Returns: - int: The length of the sampler. - """ - return len(self.indices) - self.start_index - - @property - def batch_size(self) -> int: - """ - Returns the batch size of the underlying batch sampler. - - Returns: - int: The batch size of the underlying batch sampler. - """ - return self.underlying_batch_sampler.batch_size - +from torch.utils.data import Dataset, Sampler T_co = TypeVar("T_co", covariant=True) @@ -194,16 +130,3 @@ def __iter__(self) -> Iterator[T_co]: def __len__(self) -> int: return self.local_num_samples - - def set_epoch(self, epoch: int) -> None: - r""" - Set the epoch for this sampler. - - When :attr:`shuffle=True`, this ensures all replicas - use a different random ordering for each epoch. Otherwise, the next iteration of this - sampler will yield the same ordering. - - Args: - epoch (int): Epoch number. 
- """ - self.epoch = epoch diff --git a/tests/conftest.py b/tests/conftest.py index 1aa87f87..33cb303d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,12 @@ import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -from torch.utils.data.sampler import BatchSampler, SequentialSampler from modalities.checkpointing.checkpoint_saving import CheckpointSaving from modalities.config.config import load_app_config_dict from modalities.dataloader.create_index import IndexGenerator from modalities.dataloader.dataloader import LLMDataLoader from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader -from modalities.dataloader.samplers import ResumableBatchSampler from modalities.evaluator import Evaluator from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss @@ -226,13 +224,3 @@ def torch_distributed_cleanup(): else: # see https://pytorch.org/docs/2.4/_modules/torch/cuda.html#device_count torch.cuda._cached_device_count = None - - -@pytest.fixture(scope="function") -def resumable_batch_sampler() -> ResumableBatchSampler: - data_source = list(range(12))[::-1] # torch.range(0,11)[::-1].reshape(3, 4) - seq_sampler = SequentialSampler(data_source=data_source) - - seq_sampler = BatchSampler(sampler=seq_sampler, batch_size=3, drop_last=False) - sampler = ResumableBatchSampler(start_index=2, underlying_batch_sampler=seq_sampler) - return sampler diff --git a/tests/dataloader/yaml_configs/skipped_dataloader.yaml b/tests/dataloader/yaml_configs/skipped_dataloader.yaml index 9dbd91c4..ddd81bbe 100644 --- a/tests/dataloader/yaml_configs/skipped_dataloader.yaml +++ b/tests/dataloader/yaml_configs/skipped_dataloader.yaml @@ -28,14 +28,12 @@ train_dataset: sequence_length: ${settings.training.sequence_length} sample_key: ${settings.referencing_keys.sample_key} -skip_num_batches: +skip_num_samples: component_key: number_conversion - variant_key: local_num_batches_from_num_tokens + variant_key: num_samples_from_num_tokens config: - num_ranks: ${settings.cuda_env.world_size} - global_num_tokens: ${settings.training.global_num_seen_tokens} + num_tokens: ${settings.training.global_num_seen_tokens} sequence_length: ${settings.training.sequence_length} - local_micro_batch_size: ${settings.training.local_train_micro_batch_size} train_dataloader: component_key: data_loader @@ -44,9 +42,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: - instance_key: skip_num_batches - pass_type: BY_REFERENCE dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -58,15 +53,16 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - drop_last: true - shuffle: false dataset: instance_key: train_dataset pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + skip_num_global_samples: ${skip_num_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE From 6a6270911e9024d4ed8da79380dfde3fef999aa6 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 27 Sep 2024 00:04:41 +0200 Subject: [PATCH 23/33] refactor: removed legacy tests --- ...g_without_shuffling_but_skipped_batch.yaml | 39 ----------- .../test_distributed_repeating_dataloader.py | 69 ------------------- 2 
files changed, 108 deletions(-) delete mode 100644 tests/dataloader/distributed/dist_repeating_dataloader_config_without_shuffling_but_skipped_batch.yaml delete mode 100644 tests/dataloader/distributed/test_distributed_repeating_dataloader.py diff --git a/tests/dataloader/distributed/dist_repeating_dataloader_config_without_shuffling_but_skipped_batch.yaml b/tests/dataloader/distributed/dist_repeating_dataloader_config_without_shuffling_but_skipped_batch.yaml deleted file mode 100644 index 297beda4..00000000 --- a/tests/dataloader/distributed/dist_repeating_dataloader_config_without_shuffling_but_skipped_batch.yaml +++ /dev/null @@ -1,39 +0,0 @@ -train_dataset: - component_key: dataset - variant_key: test - config: - num_samples: 8 - -train_dataloader: - component_key: data_loader - variant_key: repeating_data_loader - config: - reshuffle_after_epoch: false - num_epochs: 3 - dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - dataloader_tag: "train" - skip_num_batches: 1 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: 2 - drop_last: true - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${cuda_env:RANK} - num_replicas: ${cuda_env:WORLD_SIZE} - shuffle: false - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE \ No newline at end of file diff --git a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py deleted file mode 100644 index 7f40cc97..00000000 --- a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py +++ /dev/null @@ -1,69 +0,0 @@ -import json -import os -from pathlib import Path - -import pytest -import torch -import torch.distributed as dist -from pydantic import BaseModel - -from modalities.__main__ import Main -from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType -from modalities.running_env.cuda_env import CudaEnv -from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig - -working_dir = Path(os.path.dirname(__file__)) -tmp_folder = working_dir / "../../tmp" - - -class DataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType - - -@pytest.mark.skipif( - "RANK" not in os.environ or torch.cuda.device_count() < 2, - reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", -) -def test_resumable_dataloader_without_shuffling(): - # we test that the distributed sampler provides each process with the correct subset of the dataset. - # In the first epoch we expect the first step to be skipped but for the subsequent epochs we expect - # all dataset samples. - # Given a sequence of [0, 1, 2, 3, 4, 5, 6, 7, 8] we want each of the two processes to have the - # following batches after three epochs - # to receive [[4, 6], [0, 2], [4, 6], [0, 2], [4, 6]] and - # [[5, 7], [1, 3], [5, 7], [1, 3], [5, 7]], respectively. 
- - config_file_path = working_dir / "dist_repeating_dataloader_config_without_shuffling_but_skipped_batch.yaml" - - main = Main(config_file_path) - with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): - main.add_custom_component( - component_key="dataset", - variant_key="test", - custom_component=TestDataset, - custom_config=TestDatasetConfig, - ) - components = main.build_components(components_model_type=DataloaderInstantiationModel) - - repeating_dataloader = components.train_dataloader - num_samples = len(repeating_dataloader.dataloader.dataset) - - # each epoch has 2 batches of size 2, we want two skip the first batch in the - # first epoch and have 3 epochs in total - batches = [batch.tolist() for batch in repeating_dataloader] - - rank = dist.get_rank() - with open(tmp_folder / f"rank_{rank}_batches.json", "w") as f: - json.dump(batches, f) - - dist.barrier() - - with open(tmp_folder / "rank_0_batches.json") as f: - rank_0_batches = torch.tensor(json.load(f)) - - with open(tmp_folder / "rank_1_batches.json") as f: - rank_1_batches = torch.tensor(json.load(f)) - - samples = [i.item() for item in zip(rank_0_batches.flatten(), rank_1_batches.flatten()) for i in item] - - assert samples == (list(range(num_samples)) * 3)[4:] From b0ac3342c8dd39ed89d50b01f39b4e23dd58794c Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 27 Sep 2024 00:05:13 +0200 Subject: [PATCH 24/33] refactor: fixed e2e tests --- ...dist_dataloader_config_with_shuffling.yaml | 4 ++-- ...ig_with_shuffling_and_skipped_batches.yaml | 4 ++-- ...t_dataloader_config_without_shuffling.yaml | 4 ++-- .../test_distributed_dataloader.py | 6 ++--- .../end2end_tests/gpt2_train_num_steps_8.yaml | 16 +++++++------- .../gpt2_warm_start_from_step_4.yaml | 22 +++++++++---------- tests/end2end_tests/test_fsdp_warmstart.py | 9 +++++--- tests/run_distributed_tests.sh | 11 ++++------ 8 files changed, 37 insertions(+), 39 deletions(-) diff --git a/tests/dataloader/distributed/dist_dataloader_config_with_shuffling.yaml b/tests/dataloader/distributed/dist_dataloader_config_with_shuffling.yaml index 17106d91..949d77ad 100644 --- a/tests/dataloader/distributed/dist_dataloader_config_with_shuffling.yaml +++ b/tests/dataloader/distributed/dist_dataloader_config_with_shuffling.yaml @@ -11,7 +11,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: 0 dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -23,12 +22,13 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: rank: ${cuda_env:RANK} num_replicas: ${cuda_env:WORLD_SIZE} shuffle: true seed: 0 + skip_num_global_samples: 0 dataset: instance_key: train_dataset pass_type: BY_REFERENCE \ No newline at end of file diff --git a/tests/dataloader/distributed/dist_dataloader_config_with_shuffling_and_skipped_batches.yaml b/tests/dataloader/distributed/dist_dataloader_config_with_shuffling_and_skipped_batches.yaml index 9841886f..2a430a95 100644 --- a/tests/dataloader/distributed/dist_dataloader_config_with_shuffling_and_skipped_batches.yaml +++ b/tests/dataloader/distributed/dist_dataloader_config_with_shuffling_and_skipped_batches.yaml @@ -11,7 +11,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: 1 dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -23,12 +22,13 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + 
variant_key: resumable_distributed_sampler config: rank: ${cuda_env:RANK} num_replicas: ${cuda_env:WORLD_SIZE} shuffle: true seed: 0 + skip_num_global_samples: 4 # num_batches (1) * world_size (2) * local_micro_batch_size (2) dataset: instance_key: train_dataset pass_type: BY_REFERENCE \ No newline at end of file diff --git a/tests/dataloader/distributed/dist_dataloader_config_without_shuffling.yaml b/tests/dataloader/distributed/dist_dataloader_config_without_shuffling.yaml index 9fda73af..e6d44637 100644 --- a/tests/dataloader/distributed/dist_dataloader_config_without_shuffling.yaml +++ b/tests/dataloader/distributed/dist_dataloader_config_without_shuffling.yaml @@ -11,7 +11,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: 0 dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -23,11 +22,12 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: rank: ${cuda_env:RANK} num_replicas: ${cuda_env:WORLD_SIZE} shuffle: false + skip_num_global_samples: 0 dataset: instance_key: train_dataset pass_type: BY_REFERENCE \ No newline at end of file diff --git a/tests/dataloader/distributed/test_distributed_dataloader.py b/tests/dataloader/distributed/test_distributed_dataloader.py index 0038d04a..b1f80b68 100644 --- a/tests/dataloader/distributed/test_distributed_dataloader.py +++ b/tests/dataloader/distributed/test_distributed_dataloader.py @@ -25,7 +25,7 @@ class DataloaderInstantiationModel(BaseModel): "RANK" not in os.environ or torch.cuda.device_count() < 2, reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) -def test_resumable_dataloader_without_shuffling(): +def test_dataloader_without_shuffling(): # we test that the distributed sampler provides each process with the correct subset of the dataset # Given a sequence of [0, 1, 2, 3, 4, 5, 6, 7, 8] we want each of the two processes # to receive [[0, 2], [4, 6]] and [[1, 3], [5, 7]], respectively. @@ -69,7 +69,7 @@ def test_resumable_dataloader_without_shuffling(): "RANK" not in os.environ or torch.cuda.device_count() < 2, reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) -def test_resumable_dataloader_with_shuffling_without_skipping(): +def test_dataloader_with_shuffling_without_skipping(): # we test that the distributed sampler provides each process with the correct RANDOM subset of the dataset # Given a sequence of [0, 1, 2, 3, 4, 5, 6, 7, 8] we want each of the two processes # to receive two batches of size two without overlap, e.g., [[2, 0], [5, 6]] and [[7, 3], [4, 1]], respectively. 
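The expected batches spelled out in these comments follow from round-robin sharding across ranks plus the global sample skip. A self-contained sketch that reproduces them without torch.distributed (shuffling is left out and the helper below is invented for illustration, not taken from the library):

# Illustrative only: round-robin sharding plus a global sample skip,
# mirroring the expectations documented in the test comments above.
def rank_batches(dataset, rank, world_size, batch_size, skip_num_global_samples=0):
    shard = dataset[rank::world_size]                        # round-robin shard for this rank
    shard = shard[skip_num_global_samples // world_size:]    # drop this rank's share of skipped samples
    usable = len(shard) - len(shard) % batch_size            # drop_last=True
    return [shard[i:i + batch_size] for i in range(0, usable, batch_size)]


dataset = list(range(8))
assert rank_batches(dataset, rank=0, world_size=2, batch_size=2) == [[0, 2], [4, 6]]
assert rank_batches(dataset, rank=1, world_size=2, batch_size=2) == [[1, 3], [5, 7]]
# skipping 4 global samples (one batch per rank) removes the first batch on each rank
assert rank_batches(dataset, rank=0, world_size=2, batch_size=2, skip_num_global_samples=4) == [[4, 6]]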
@@ -114,7 +114,7 @@ def test_resumable_dataloader_with_shuffling_without_skipping(): "RANK" not in os.environ or torch.cuda.device_count() < 2, reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) -def test_resumable_dataloader_with_shuffling_and_skipped_batches(): +def test_dataloader_with_shuffling_and_skipped_batches(): # we test that the distributed sampler provides each process with the correct RANDOM subset of the dataset # additionally we skip one batch # Given a sequence of [0, 1, 2, 3, 4, 5, 6, 7, 8] we want each of the two processes diff --git a/tests/end2end_tests/gpt2_train_num_steps_8.yaml b/tests/end2end_tests/gpt2_train_num_steps_8.yaml index 0ca58132..4954e6a9 100644 --- a/tests/end2end_tests/gpt2_train_num_steps_8.yaml +++ b/tests/end2end_tests/gpt2_train_num_steps_8.yaml @@ -47,10 +47,10 @@ settings: training_progress: global_num_seen_tokens: 0 num_seen_steps: 0 - local_num_seen_batches: 0 + num_seen_samples: 0 last_step: -1 -collate_fn: +collate_fn: component_key: collate_fn variant_key: gpt_2_llm_collator config: @@ -72,7 +72,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -84,16 +83,17 @@ train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE diff --git a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml index 03ba9c16..1a7c9da6 100644 --- a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml +++ b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml @@ -47,14 +47,12 @@ settings: variant_key: num_seen_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} - local_num_seen_batches: # for the dataloader + num_seen_samples: component_key: number_conversion - variant_key: local_num_batches_from_num_tokens + variant_key: num_samples_from_num_tokens config: - num_ranks: ${settings.cuda_env.world_size} - global_num_tokens: ${settings.training_progress.global_num_seen_tokens} + num_tokens: ${settings.training_progress.global_num_seen_tokens} sequence_length: ${settings.step_profile.sequence_length} - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} last_step: # for the scheduler component_key: number_conversion variant_key: last_step_from_checkpoint_path @@ -66,7 +64,7 @@ settings: model_checkpoint_path: eid_0-model-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin optimizer_checkpoint_path: eid_0-optimizer-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin -collate_fn: +collate_fn: component_key: collate_fn variant_key: gpt_2_llm_collator config: @@ -88,7 +86,6 @@ train_dataloader: num_workers: 2 pin_memory: true dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} dataset: instance_key: train_dataset pass_type: BY_REFERENCE @@ -100,16 +97,17 @@ 
train_dataloader: drop_last: true sampler: component_key: sampler - variant_key: distributed_sampler + variant_key: resumable_distributed_sampler config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE rank: ${settings.cuda_env.global_rank} num_replicas: ${settings.cuda_env.world_size} shuffle: true - drop_last: true seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE diff --git a/tests/end2end_tests/test_fsdp_warmstart.py b/tests/end2end_tests/test_fsdp_warmstart.py index f41233ff..bb1b43e1 100644 --- a/tests/end2end_tests/test_fsdp_warmstart.py +++ b/tests/end2end_tests/test_fsdp_warmstart.py @@ -46,6 +46,7 @@ class SaveAllResultSubscriberConfig(BaseModel): class TrainDataloaderInstantiationModel(BaseModel): + settings: TrainingComponentsInstantiationModel.Settings train_dataloader: PydanticLLMDataLoaderIFType @@ -200,7 +201,9 @@ def test_warmstart_dataloader(self): custom_component=SaveAllResultSubscriber, custom_config=SaveAllResultSubscriberConfig, ) - components_1 = main_obj_1.build_components(components_model_type=TrainDataloaderInstantiationModel) + components_1: TrainDataloaderInstantiationModel = main_obj_1.build_components( + components_model_type=TrainDataloaderInstantiationModel + ) dataloader_1: LLMDataLoader = components_1.train_dataloader dl_1_samples = [s for s in dataloader_1] @@ -216,14 +219,14 @@ def test_warmstart_dataloader(self): # fast forward the first dataloader - num_skip_steps = dataloader_2.fast_forward_batch_id + num_skip_steps = components_2.settings.training_progress.num_seen_steps # make sure that we actually skip as defined in the config assert num_skip_steps == 4 assert len(dl_1_samples) == num_skip_steps + len(dl_2_samples) # make sure that the first dataloader is not skipped - assert dataloader_1.fast_forward_batch_id == 0 + assert components_1.settings.training_progress.num_seen_steps == 0 # iterate through both sample lists from the dataloaders # and assert the equality of the samples diff --git a/tests/run_distributed_tests.sh b/tests/run_distributed_tests.sh index ae39ad82..6994bd5a 100755 --- a/tests/run_distributed_tests.sh +++ b/tests/run_distributed_tests.sh @@ -40,13 +40,10 @@ COVERAGE_FILE=.coverage_reports/.coverage.part1 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 COVERAGE_FILE=.coverage_reports/.coverage.part2 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/end2end_tests/test_fsdp_warmstart.py -k "test_warm_start" $COVERAGE COVERAGE_FILE=.coverage_reports/.coverage.part3 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/end2end_tests/test_fsdp_warmstart.py -k "test_warmstart_dataloader" $COVERAGE -# # test_distributed_repeating_dataloader -COVERAGE_FILE=.coverage_reports/.coverage.part4 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_repeating_dataloader.py -k "test_resumable_dataloader_without_shuffling" $COVERAGE - # # test_distributed_dataloader -COVERAGE_FILE=.coverage_reports/.coverage.part5 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 
coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_dataloader.py -k "test_resumable_dataloader_without_shuffling" $COVERAGE -COVERAGE_FILE=.coverage_reports/.coverage.part6 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_dataloader.py -k "test_resumable_dataloader_with_shuffling_without_skipping" $COVERAGE -COVERAGE_FILE=.coverage_reports/.coverage.part7 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_dataloader.py -k "test_resumable_dataloader_with_shuffling_and_skipped_batches" $COVERAGE +COVERAGE_FILE=.coverage_reports/.coverage.part5 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_dataloader.py -k "test_dataloader_without_shuffling" $COVERAGE +COVERAGE_FILE=.coverage_reports/.coverage.part6 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_dataloader.py -k "test_dataloader_with_shuffling_without_skipping" $COVERAGE +COVERAGE_FILE=.coverage_reports/.coverage.part7 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/dataloader/distributed/test_distributed_dataloader.py -k "test_dataloader_with_shuffling_and_skipped_batches" $COVERAGE # # test optimizer COVERAGE_FILE=.coverage_reports/.coverage.part8 CUDA_VISIBLE_DEVICES=$DEV0 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 1 $(which pytest) tests/test_optimizer_factory.py $COVERAGE @@ -58,5 +55,5 @@ COVERAGE_FILE=.coverage_reports/.coverage.part9 CUDA_VISIBLE_DEVICES=$DEV0 cover COVERAGE_FILE=.coverage_reports/.coverage.part10 CUDA_VISIBLE_DEVICES=$DEV0 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 1 $(which pytest) tests/utils/test_mfu.py $COVERAGE # test activation checkpointing -COVERAGE_FILE=.coverage_reports/.coverage.part11 CUDA_VISIBLE_DEVICES=$DEV0 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 1 $(which pytest) tests/training/test_activation_checkpointing.py $COVERAGE +COVERAGE_FILE=.coverage_reports/.coverage.part11 CUDA_VISIBLE_DEVICES=$DEV0,$DEV1 coverage run --rcfile=.coveragerc --parallel $(which torchrun) --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 $(which pytest) tests/training/test_activation_checkpointing.py $COVERAGE From aed1d3dbe11fda36d226f462c2270141dae16e55 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 27 Sep 2024 10:40:35 +0200 Subject: [PATCH 25/33] chore: updated documentation --- CHANGELOG_DEV.md | 23 +++++++++++++++++++++++ docs/components/components.md | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git 
a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index 7d4fb0cd..5d6eda63 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -65,3 +65,26 @@ This [PR](https://github.com/Modalities/modalities/pull/236) removes all code re **Breaking changes:** * None + +## PR #261 Dataloader inefficiencies fix and combined dataset feature + +This PR addresses issue #258 (inefficiencies in the dataloader) and additionally introduces a combined dataset feature, where a dataset can now comprise a list of datasets and iterate over them. +As part of fixing the dataloader inefficiencies, we now implement the sample-skipping functionality in an adapted version of the PyTorch `DistributedSampler` rather than at the dataloader level. I reran a warm start and the learning is equivalent to a full, non-warmstarted run. + +Screenshot 2024-09-27 at 10 36 19 + + +**General Changes** +* Introduced `ResumableDistributedSampler`, a copy of the PyTorch `DistributedSampler` extended with the ability to skip samples. It is now used for warmstarts instead of the `skip_num_batches` mechanism in the dataloader. When skipping samples, the dataloader previously had to instantiate a `ResumableBatchSampler`, which internally iterated over all dataset indices. For small datasets this was fine, but for larger datasets (in the trillion-token range) this became a bottleneck at instantiation time: +https://github.com/Modalities/modalities/blob/b79d04d3e92d0845c5ec91f8dd41176fd543cb23/src/modalities/dataloader/samplers.py#L25-L28 +Skipping in the `ResumableDistributedSampler` is now done in O(1) (a minimal sketch of the idea follows below this changelog entry). The `ResumableBatchSampler` was removed from the codebase. +* Replaced the packed index generation routine (inefficient due to a for loop) +https://github.com/Modalities/modalities/blob/b79d04d3e92d0845c5ec91f8dd41176fd543cb23/src/modalities/dataloader/dataset.py#L331-L334 +with a vectorized version. +* Added a new `NumberConversion` routine `num_samples_from_num_tokens`. + +**Breaking Changes** +* Removed `RepeatingDataLoader`, a feature that was never actively used for running multiple epochs and that complicated maintenance when refactoring the sampling. If needed, we could reimplement it. +* In the settings, the `training_progress` section now has `num_seen_samples` instead of `local_num_seen_batches`, as skipping is now done on the sampler level and not on the dataloader level anymore. +* The `batch_size` and `fast_forward_batch_id` fields in the `LLMDataLoader` are no longer needed and were removed.
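The O(1) skipping described in the changelog above can be made concrete with a small, self-contained sketch. This is not the `ResumableDistributedSampler` implementation from the PR — it covers only the non-shuffled case and assumes the number of globally seen samples is divisible by the number of replicas — but it illustrates the idea: the resume point is computed arithmetically instead of iterating over all previously consumed indices.

```python
from typing import List


def resume_shard(dataset_len: int, rank: int, num_replicas: int, skip_num_global_samples: int) -> List[int]:
    """Round-robin shard for `rank`, fast-forwarded past the already seen samples.

    Illustrative sketch only (non-shuffled case; assumes skip_num_global_samples
    is a multiple of num_replicas): the resume offset is a constant-time
    arithmetic computation, not an iteration over previously consumed indices.
    """
    local_skip = skip_num_global_samples // num_replicas  # samples this rank has already seen
    start = rank + local_skip * num_replicas              # first unseen global index for this rank
    return list(range(start, dataset_len, num_replicas))


# 10 samples, 2 ranks, 4 globally seen samples:
print(resume_shard(dataset_len=10, rank=0, num_replicas=2, skip_num_global_samples=4))  # [4, 6, 8]
print(resume_shard(dataset_len=10, rank=1, num_replicas=2, skip_num_global_samples=4))  # [5, 7, 9]
```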
+ diff --git a/docs/components/components.md b/docs/components/components.md index a7079628..eeb625ae 100644 --- a/docs/components/components.md +++ b/docs/components/components.md @@ -77,7 +77,6 @@ |Component type | Component Version | Implementation | Configuration | Component Interface | Description | |---------------|--------------------|----------------|---------------|---------------------|-------------| | data_loader | default | [DataloaderFactory.get_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [LLMDataLoaderConfig](s../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | LLM Data loader extending pytorch data loader functionality | -| data_loader | repeating_data_loader | [DataloaderFactory.get_repeating_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [RepeatingDataLoaderConfig](../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | Data loader that repeats the given dataloader for the specified number of epochs. | ## Checkpointing @@ -119,6 +118,7 @@ |---------------|--------------------|----------------|---------------|---------------------|-------------| | number_conversion | local_num_batches_from_num_samples | [NumberConversion.get_local_num_batches_from_num_samples](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of samples and number of ranks. | | number_conversion | local_num_batches_from_num_tokens | [NumberConversion.get_local_num_batches_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of tokens and number of ranks. | +| number_conversion | num_samples_from_num_tokens | [NumberConversion.get_num_samples_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumSamplesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of global samples, given the global number of tokens and sequence length. | | number_conversion | num_steps_from_num_samples | [NumberConversion.get_num_steps_from_num_samples](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of samples, local micro batch size and number of ranks. | | number_conversion | num_steps_from_num_tokens | [NumberConversion.get_num_steps_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of tokens, local micro batch size and number of ranks.
| | number_conversion | num_tokens_from_num_steps | [NumberConversion.get_num_tokens_from_num_steps](../../src/modalities/utils/number_conversion.py)| [NumTokensFromNumStepsConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of tokens from the number of steps, number of ranks, local micro batch size, global number of tokens, sequence length and gradient accumulation steps | From 4d27cae9eb11f1d2fcd9cc3242b9806dc210d4d3 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 27 Sep 2024 13:55:49 +0200 Subject: [PATCH 26/33] chore: fixed minor path issue --- src/modalities/__main__.py | 2 +- tests/tests.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index fd139fef..6996c4a8 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -148,7 +148,7 @@ def entry_point_data_create_raw_index(src_path: Path, index_path: Path): index_path = LargeFileLinesReader.default_index_path(src_path, index_path) if index_path.exists(): - raise ValueError("index already exists. delete it or specify different output folder.") + raise ValueError(f"Index already exists in {index_path}. Delete it or specify a different output folder.") print(f"reading raw data from {src_path}") print(f"writing index to {index_path}") diff --git a/tests/tests.py b/tests/tests.py index c4e0bc25..5a4adea1 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -106,9 +106,9 @@ def main(cpu: bool = False, single_gpu: bool = False, multi_gpu: bool = False, d # getting started example print("\n=== RUN GETTING STARTED EXAMPLE ===") - run_getting_started_example_directory = _ROOT_DIR / "examples" / "getting_started" + run_getting_started_example_directory = _ROOT_DIR / "tutorials" / "getting_started" run_getting_started_example_script = ( - _ROOT_DIR / "examples" / "getting_started" / "run_getting_started_example.sh" + _ROOT_DIR / "tutorials" / "getting_started" / "run_getting_started_example.sh" ) assert isfile( run_getting_started_example_script From 2e8a880729b5a623c2672cf45c50f0ab05ffe463 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 24 Oct 2024 00:00:42 +0200 Subject: [PATCH 27/33] refactor: improved test_skipped_and_distributed_dataloader_from_config --- tests/dataloader/test_dataloader.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 8e264100..2511c885 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -103,6 +103,17 @@ class DataloaderTestModel(BaseModel): dataset_indices_rank_0 = np.arange(0, 28, 2).reshape(-1, local_micro_batch_size)[skip_num_local_batches:] dataset_indices_rank_1 = np.arange(1, 29, 2).reshape(-1, local_micro_batch_size)[skip_num_local_batches:] + # make sure that the recreated dataset index with the hardcoded 28 elements + # fits the actual dataset used in the config + effective_dataset_length = len(components_rank_0.train_dataloader.dataset) // world_size // local_micro_batch_size + effective_dataset_length = effective_dataset_length * local_micro_batch_size * world_size + recalculated_dataset_length = ( + len(dataset_indices_rank_0.flatten()) + + len(dataset_indices_rank_1.flatten()) + + skip_num_local_batches * world_size * local_micro_batch_size + ) + assert recalculated_dataset_length == effective_dataset_length + assert np.all((dataset_indices_rank_0 == list(components_rank_0.train_dataloader.batch_sampler))) assert
np.all((dataset_indices_rank_1 == list(components_rank_1.train_dataloader.batch_sampler))) From 6d9b88a1428ea7be46155a32fd5ed37e8a2608aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:29:18 +0200 Subject: [PATCH 28/33] Update src/modalities/dataloader/dataset.py Co-authored-by: Felix Stollenwerk --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index f3139895..be5985db 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -392,7 +392,7 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: class CombinedDataset(Dataset): """Combines multiple datasets into one large dataset at runtime. - Note: When using this class to combine multiple `PackedMemMapDatasetes`, then each packed sample + Note: When using this class to combine multiple `PackedMemMapDataset`s, then each packed sample is packed from a single dataset (i.e., the samples are not mixed between datasets). In the Dataloader a batch will still contain packed samples from different datasets. """ From b067718a11df81df10540feeb61e881e94df1301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:38:47 +0200 Subject: [PATCH 29/33] Update tests/dataloader/samplers/test_distributed_samplers.py Co-authored-by: Felix Stollenwerk --- tests/dataloader/samplers/test_distributed_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dataloader/samplers/test_distributed_samplers.py b/tests/dataloader/samplers/test_distributed_samplers.py index dc7ab092..ac0c4e2e 100644 --- a/tests/dataloader/samplers/test_distributed_samplers.py +++ b/tests/dataloader/samplers/test_distributed_samplers.py @@ -50,7 +50,7 @@ def test_dropping_and_reusing( samples_left = len(dataset) - skip_num_global_samples padding_size = math.ceil(samples_left / num_replicas) * num_replicas - samples_left # when drop_last false, we reuse the last samples (i.e., maximum num_ranks -1) - # so that every data parallel ran, has a full last batch + # so that every data parallel rank has a full last batch padded_samples = dataset[:padding_size] assert dataset[skip_num_global_samples:cut_off_samples] + padded_samples == list( From cc9ef97cab9defdbd2c5dce6466a89728653941f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:39:47 +0200 Subject: [PATCH 30/33] Update tests/dataloader/samplers/test_distributed_samplers.py Co-authored-by: Felix Stollenwerk --- tests/dataloader/samplers/test_distributed_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dataloader/samplers/test_distributed_samplers.py b/tests/dataloader/samplers/test_distributed_samplers.py index ac0c4e2e..de39bcfe 100644 --- a/tests/dataloader/samplers/test_distributed_samplers.py +++ b/tests/dataloader/samplers/test_distributed_samplers.py @@ -39,7 +39,7 @@ def test_dropping_and_reusing( samples = [[dataset[i] for i in sampler] for sampler in samplers] if drop_last: - # when drop_last true, we drop the last samples so that every data parallel rank + # if drop_last is true, we drop the last samples so that every data parallel rank # has the same number of samples.
# Note that also means that the last, remaining samples (i.e., maximum num_ranks -1) # are not used at all From 09554b1d7321662f787ebcf561c2653ca55e9ea9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:42:14 +0200 Subject: [PATCH 31/33] Update tests/dataloader/samplers/test_distributed_samplers.py Co-authored-by: Felix Stollenwerk --- tests/dataloader/samplers/test_distributed_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dataloader/samplers/test_distributed_samplers.py b/tests/dataloader/samplers/test_distributed_samplers.py index de39bcfe..868da3f8 100644 --- a/tests/dataloader/samplers/test_distributed_samplers.py +++ b/tests/dataloader/samplers/test_distributed_samplers.py @@ -12,7 +12,7 @@ (30, 0, False, 0, True, 0), # drop_last has no effect because integer divisible (30, 0, False, 0, False, 9), (30, 0, False, 0, True, 9), # drop_last has no effect because integer divisible - (30, 0, False, 0, True, 10), # drop_last has an effect because not integer divisible + (30, 0, False, 0, True, 10), # drop_last has an effect because not integer divisible (2 samples dropped) (30, 0, False, 0, False, 10), # we have to reuse the initial samples (1 sample) ], ) From 856fba77676b52933e6635b8f69335fb35849e03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:46:13 +0200 Subject: [PATCH 32/33] Update tests/dataloader/samplers/test_distributed_samplers.py Co-authored-by: Felix Stollenwerk --- tests/dataloader/samplers/test_distributed_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dataloader/samplers/test_distributed_samplers.py b/tests/dataloader/samplers/test_distributed_samplers.py index 868da3f8..962c08db 100644 --- a/tests/dataloader/samplers/test_distributed_samplers.py +++ b/tests/dataloader/samplers/test_distributed_samplers.py @@ -41,7 +41,7 @@ def test_dropping_and_reusing( if drop_last: # if drop_last is true, we drop the last samples so that every data parallel rank # has the same number of samples. - # Note that also means that the last, remaining samples (i.e., maximum num_ranks -1) + # Note that this also means that the last remaining samples (i.e., maximum num_replicas - 1) # are not used at all From c2e6b8c5bdd6883844a947c283b00213366ecf3a Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 24 Oct 2024 14:47:53 +0200 Subject: [PATCH 33/33] refactor: fixed typos --- src/modalities/dataloader/create_packed_data.py | 1 + src/modalities/dataloader/dataset.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index fda2fd63..341e50db 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -329,6 +329,7 @@ def __init__(self, data_path: Path, load_index: Optional[bool] = True): Args: data_path (Path): The path to the packed data file. + load_index (bool, optional): Whether to load the index. Defaults to True. Raises: FileNotFoundError: If the packed data file is not found at the specified path.
diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index f3139895..09751ef6 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -404,16 +404,16 @@ def __init__(self, datasets: List[Dataset]): datasets (List[Dataset]): A list of datasets to combine. """ self.datasets = datasets - self.cumulative_sizes = CombinedDataset._get_cummulated_sizes(datasets=datasets) + self.cumulative_sizes = CombinedDataset._get_cumulated_sizes(datasets=datasets) @staticmethod - def _get_cummulated_sizes(datasets: List[Dataset]) -> List[int]: + def _get_cumulated_sizes(datasets: List[Dataset]) -> List[int]: total = 0 - cummulated_sizes = [0] + cumulated_sizes = [0] for dataset in datasets: total += len(dataset) - cummulated_sizes.append(total) - return cummulated_sizes + cumulated_sizes.append(total) + return cumulated_sizes def _find_dataset_idx(self, idx: int) -> int: for i, cumulative_size in enumerate(self.cumulative_sizes):