From a5ea22cf27d11d558db471a844e5f88db63b3335 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 11 Jul 2024 10:00:12 +0200 Subject: [PATCH 1/7] feat: added fixed number of elements to ResumableBatchSampler --- src/modalities/dataloader/samplers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index c5ab2699..a21d171b 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -1,21 +1,31 @@ +from typing import Optional + from torch.utils.data import BatchSampler, Sampler class ResumableBatchSampler(Sampler): - def __init__(self, start_index: int, underlying_batch_sampler: BatchSampler): + def __init__( + self, start_index: int, underlying_batch_sampler: BatchSampler, max_num_elements: Optional[int] = None + ): """Sampler which starts at a specified batch index and continues sampling for for a given sampler. Works with normal samplers and BatchSamplers. Args: start_index (int): index to start sampling from existing_sampler (Sampler): Sampler from which we want to continue + max_num_elements (Optional[int]): The maximum number of elements the sampler returns. Default None. """ self.start_index = start_index + self.max_num_elements = max_num_elements self.underlying_batch_sampler = underlying_batch_sampler # NOTE: we are only iterating ove the indices not the actual data # so this is relatively cheap self.indices = list(iter(self.underlying_batch_sampler)) + # We discard the samples that come after max_num_elements + # NOTE, that skipping is implemented in __iter__ and __len__. + if self.max_num_elements is not None: + self.indices = self.indices[:max_num_elements] def __iter__(self): return iter(self.indices[self.start_index :]) From f63107527df6a036185cf29efec5b6c3bb63ae3e Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 11 Jul 2024 10:00:36 +0200 Subject: [PATCH 2/7] feat: added fixed number of batches to dataloader --- config_files/training/config_lorem_ipsum.yaml | 10 ++++++++++ src/modalities/dataloader/dataloader_factory.py | 17 ++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index fc1d8926..38f70915 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -9,6 +9,7 @@ settings: checkpointing_interval_in_steps: 4 evaluation_interval_in_steps: 2 global_num_seen_tokens: 0 + global_num_train_tokens: 1000 activation_checkpointing_modules: [GPT2Block] gradient_acc_steps: 2 local_train_micro_batch_size: 1 @@ -43,6 +44,15 @@ train_dataloader: pin_memory: true shuffle: false dataloader_tag: "train" + fixed_num_batches: + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.training.local_train_micro_batch_size} + global_num_tokens: ${settings.training.global_num_train_tokens} + sequence_length: ${settings.training.sequence_length} + gradient_acc_steps: ${settings.training.gradient_acc_steps} dataset: instance_key: train_dataset pass_type: BY_REFERENCE diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 006bf11d..f3398821 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -5,6 +5,7 @@ from 
modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader from modalities.dataloader.samplers import ResumableBatchSampler +from modalities.exceptions import ConfigError class DataloaderFactory: @@ -18,6 +19,7 @@ def get_dataloader( pin_memory: bool, shuffle: bool, skip_num_batches: Optional[int] = 0, + fixed_num_batches: Optional[int] = None, ) -> LLMDataLoader: """Factory method for the instantiation of LLMDataLoader @@ -34,11 +36,23 @@ def get_dataloader( skip_num_batches must not be confused with the number of optimizer steps! skip_num_batches = num optimizer steps * gradient accumulation steps Defaults to 0. + fixed_num_batches: (int, optional): Fixed length of the dataloader by cutting off subsequent batches. + Make sure that the dataloader has at least fixed_num_batches. Defaults to None. Returns: LLMDataLoader: Instance of LLMDataLoader """ - batch_sampler = ResumableBatchSampler(start_index=skip_num_batches, underlying_batch_sampler=batch_sampler) + + batch_sampler = ResumableBatchSampler( + start_index=skip_num_batches, underlying_batch_sampler=batch_sampler, max_num_elements=fixed_num_batches + ) + + # make sure that the batch sampler has enough elements such that we can fix the number of batches to num_batches + if fixed_num_batches is not None and len(batch_sampler) < fixed_num_batches - skip_num_batches: + raise ConfigError( + f"The dataloader contains only {len(batch_sampler)} batches, which is less than " + f"specified fixed amount of batches of {fixed_num_batches}." + ) dataloader = LLMDataLoader( dataloader_tag=dataloader_tag, @@ -49,6 +63,7 @@ def get_dataloader( pin_memory=pin_memory, shuffle=shuffle, ) + return dataloader @staticmethod From 2e13f05cb1741064239795a0c12e3c9f1658be9e Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 18 Jul 2024 00:54:03 +0200 Subject: [PATCH 3/7] fix: fixed error in fixed_num_batches calculation --- src/modalities/config/config.py | 1 + src/modalities/dataloader/dataloader_factory.py | 4 +++- src/modalities/utils/number_conversion.py | 3 +-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 1db00b5a..f7e6877f 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -313,6 +313,7 @@ class LLMDataLoaderConfig(BaseModel): pin_memory: bool shuffle: bool skip_num_batches: Optional[int] = 0 + fixed_num_batches: Optional[int] = None class RepeatingDataLoaderConfig(BaseModel): diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index f3398821..01261aa6 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -37,7 +37,9 @@ def get_dataloader( skip_num_batches = num optimizer steps * gradient accumulation steps Defaults to 0. fixed_num_batches: (int, optional): Fixed length of the dataloader by cutting off subsequent batches. - Make sure that the dataloader has at least fixed_num_batches. Defaults to None. + Note that these are NOT the global number of batches, but the amount of batches that an + individual rank sees. Make sure that the dataloader has at least fixed_num_batches. + Defaults to None. 
Returns: LLMDataLoader: Instance of LLMDataLoader diff --git a/src/modalities/utils/number_conversion.py b/src/modalities/utils/number_conversion.py index 71979a9a..0e9da3a9 100644 --- a/src/modalities/utils/number_conversion.py +++ b/src/modalities/utils/number_conversion.py @@ -54,8 +54,7 @@ def get_local_num_batches_from_num_samples(num_ranks: int, global_num_samples: i def get_local_num_batches_from_num_tokens(num_ranks: int, global_num_tokens: int, sequence_length: int) -> int: """Calculates the number of local batches for each rank, given the global number of tokens and number of ranks. - This helper function is primarily used to calculate the number of batches to - skip when resuming a dataloader during warmstart. + This helper function is primarily used to calculate a dataloader's number of batches (total and to skip) Args: num_ranks (int): _description_ From 540fd8cf18a300d53f3b8507dd6a6e0af7ff4666 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 18 Jul 2024 00:54:42 +0200 Subject: [PATCH 4/7] feat: implemented test for fixed_num_batches in dataloader --- src/modalities/dataloader/dataset.py | 15 ++++- tests/dataloader/test_dataloader.py | 58 ++++++++++++++++++- .../dataloader_with_fixed_num_batches.yaml | 56 ++++++++++++++++++ 3 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 1730971f..ae536cdc 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import jq import numpy as np @@ -73,6 +74,18 @@ def _create_random_sample(self): return sample +class SequenceDataset(Dataset): + def __init__(self, sequence: Sequence): + super().__init__(raw_data_path=None, sample_key=None) + self.sequence = sequence + + def __len__(self) -> int: + return len(self.sequence) + + def __getitem__(self, idx: int) -> Any: + return self.sequence[idx] + + class MemMapDataset(Dataset): def __init__( self, diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index bdb623c5..f8df9039 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,8 +1,9 @@ import math from pathlib import Path -from typing import Dict +from typing import Dict, List import numpy as np +import pytest import torch from pydantic import BaseModel from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler @@ -11,7 +12,9 @@ from modalities.config.config import load_app_config_dict from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader +from modalities.dataloader.dataset import SequenceDataset from modalities.dataloader.samplers import ResumableBatchSampler +from modalities.models.gpt2.collator import CollateFnIF from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry @@ -223,3 +226,56 @@ class DataloaderTestModel(BaseModel): for batch_1, batch_2 in zip(batches_rank_0, batches_rank_1): assert ~(batch_1.samples["input_ids"] == batch_2.samples["input_ids"]).all() assert ~(batch_1.targets["target_ids"] == batch_2.targets["target_ids"]).all() + + 
+@pytest.mark.parametrize(
+    "global_rank",
+    [0, 1],
+)
+def test_dataloader_with_fixed_num_batches(global_rank):
+    class DataloaderTestModel(BaseModel):
+        train_dataloader: PydanticLLMDataLoaderIFType
+        fixed_num_batches: int
+
+    class IdentityCollateFn(CollateFnIF):
+        def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
+            return batch
+
+    root_dir = Path(__file__).parents[0]
+
+    config_path = root_dir / "yaml_configs/dataloader_with_fixed_num_batches.yaml"
+    # We inject a prebuilt dataset, a collate_fn and the global rank constant from outside.
+    dataset = SequenceDataset(list(range(1000)))
+    config_dict = load_app_config_dict(config_path)
+    config_dict["settings"]["cuda_env"]["global_rank"] = global_rank
+    config_dict["train_dataloader"]["config"]["batch_sampler"]["config"]["sampler"]["config"]["rank"] = global_rank
+    config_dict["train_dataset"] = dataset
+    config_dict["collate_fn"] = IdentityCollateFn()
+
+    # build the remaining components
+    registry = Registry(COMPONENTS)
+    component_factory = ComponentFactory(registry=registry)
+    components: DataloaderTestModel = component_factory.build_components(
+        config_dict=config_dict, components_model_type=DataloaderTestModel
+    )
+    dataloader = components.train_dataloader
+
+    # Calculate the fixed_num_batches and compare it with the one calculated
+    # during the component build, as well as with the dataloader length.
+    cfg = config_dict["settings"]["training"]
+    world_size = config_dict["settings"]["cuda_env"]["world_size"]
+    calculated_fixed_num_batches = cfg["global_num_train_tokens"] // cfg["sequence_length"] // world_size
+    assert calculated_fixed_num_batches == components.fixed_num_batches
+    assert len(dataloader) == calculated_fixed_num_batches
+
+    # We make sure that the dataloader outputs the correct batches as follows:
+    # The dataset contains 1000 samples (NOTE that we neglect sequence_length and make each sample an integer value).
+    # We calculated 16 batches per rank above and have 2 ranks in total.
+    # Therefore, the dataloader for rank 0 returns 16 ordered batches of batch_size 2.
+    # The batches are ordered and not shuffled, as per the YAML configuration.
+ # We expect the following output: + # [[0, 2], [4, 6], [8, 10], ..., [56, 58], [60, 62]] (global_rank=0) + # [[1, 3], [5, 7], [9, 11], ..., [57, 59], [61, 63]] (global_rank=1) + calculated_dataloader_content = np.array(list(range(global_rank, 64 + global_rank, 2))).reshape(-1, 2).tolist() + actual_dataloader_content = [i for i in dataloader] + assert calculated_dataloader_content == actual_dataloader_content diff --git a/tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml b/tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml new file mode 100644 index 00000000..bc3c3b0e --- /dev/null +++ b/tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml @@ -0,0 +1,56 @@ +settings: + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + local_train_micro_batch_size: 2 + global_num_seen_tokens: 0 + global_num_train_tokens: 128 + sequence_length: 4 + cuda_env: + global_rank: 0 + world_size: 2 + +fixed_num_batches: + component_key: number_conversion + variant_key: local_num_batches_from_num_tokens + config: + num_ranks: ${settings.cuda_env.world_size} + global_num_tokens: ${settings.training.global_num_train_tokens} + sequence_length: ${settings.training.sequence_length} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: train + skip_num_batches: 0 + fixed_num_batches: + instance_key: fixed_num_batches + pass_type: BY_REFERENCE + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + drop_last: true + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE From 908f9f68f6a32bd6ab3b63787f8c4ca9b9468e66 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 18 Jul 2024 00:54:59 +0200 Subject: [PATCH 5/7] refactor: removed fixed_num_batches from config_lorem_ipsum.yaml --- config_files/training/config_lorem_ipsum.yaml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index 38f70915..fa1e085e 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -9,7 +9,6 @@ settings: checkpointing_interval_in_steps: 4 evaluation_interval_in_steps: 2 global_num_seen_tokens: 0 - global_num_train_tokens: 1000 activation_checkpointing_modules: [GPT2Block] gradient_acc_steps: 2 local_train_micro_batch_size: 1 @@ -44,14 +43,6 @@ train_dataloader: pin_memory: true shuffle: false dataloader_tag: "train" - fixed_num_batches: - component_key: number_conversion - variant_key: num_steps_from_num_tokens - config: - num_ranks: ${settings.cuda_env.world_size} - local_micro_batch_size: ${settings.training.local_train_micro_batch_size} - global_num_tokens: ${settings.training.global_num_train_tokens} - sequence_length: ${settings.training.sequence_length} gradient_acc_steps: ${settings.training.gradient_acc_steps} dataset: instance_key: train_dataset From 89c2b25d134552941dc14c7689fa75ddec1d9b79 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 19 Jul 2024 12:58:25 
+0200
Subject: [PATCH 6/7] refactor: moved SequenceDataset to test

---
 src/modalities/dataloader/dataset.py | 15 +--------------
 tests/dataloader/test_dataloader.py  | 17 +++++++++++++++--
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py
index ae536cdc..1730971f 100644
--- a/src/modalities/dataloader/dataset.py
+++ b/src/modalities/dataloader/dataset.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 
-from collections.abc import Sequence
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import jq
 import numpy as np
@@ -74,18 +73,6 @@ def _create_random_sample(self):
         return sample
 
 
-class SequenceDataset(Dataset):
-    def __init__(self, sequence: Sequence):
-        super().__init__(raw_data_path=None, sample_key=None)
-        self.sequence = sequence
-
-    def __len__(self) -> int:
-        return len(self.sequence)
-
-    def __getitem__(self, idx: int) -> Any:
-        return self.sequence[idx]
-
-
 class MemMapDataset(Dataset):
     def __init__(
         self,
diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py
index f8df9039..44cb55e6 100644
--- a/tests/dataloader/test_dataloader.py
+++ b/tests/dataloader/test_dataloader.py
@@ -1,6 +1,7 @@
 import math
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import numpy as np
 import pytest
 import torch
 from pydantic import BaseModel
 from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler
@@ -12,13 +13,25 @@
 from modalities.config.config import load_app_config_dict
 from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType
 from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
-from modalities.dataloader.dataset import SequenceDataset
+from modalities.dataloader.dataset import Dataset
 from modalities.dataloader.samplers import ResumableBatchSampler
 from modalities.models.gpt2.collator import CollateFnIF
 from modalities.registry.components import COMPONENTS
 from modalities.registry.registry import Registry
 
 
+class SequenceDataset(Dataset):
+    def __init__(self, sequence: Sequence):
+        super().__init__(raw_data_path=None, sample_key=None)
+        self.sequence = sequence
+
+    def __len__(self) -> int:
+        return len(self.sequence)
+
+    def __getitem__(self, idx: int) -> Any:
+        return self.sequence[idx]
+
+
 def test_resumable_dataloader():
     batch_size = 3
     start_index = 2

From e9e93ab9ea7e5abc383f6438b5df4f3fe7f5116c Mon Sep 17 00:00:00 2001
From: Max Luebbering
Date: Fri, 19 Jul 2024 12:58:59 +0200
Subject: [PATCH 7/7] refactor: added another check that fixed_num_batches > skip_num_batches

---
 src/modalities/dataloader/dataloader_factory.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py
index 01261aa6..e7e06940 100644
--- a/src/modalities/dataloader/dataloader_factory.py
+++ b/src/modalities/dataloader/dataloader_factory.py
@@ -49,6 +49,9 @@ def get_dataloader(
         start_index=skip_num_batches, underlying_batch_sampler=batch_sampler, max_num_elements=fixed_num_batches
     )
 
+    if fixed_num_batches is not None and fixed_num_batches <= skip_num_batches:
+        raise ConfigError("fixed_num_batches must be larger than skip_num_batches")
+
     # make sure that the batch sampler has enough elements such that we can fix the number of batches to num_batches
    if fixed_num_batches is not None and len(batch_sampler) < fixed_num_batches - skip_num_batches:
         raise ConfigError(
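
Taken together, the two parameters of the patched ResumableBatchSampler are applied in a fixed order: the index list is first cut down to max_num_elements, and the start_index skipping then happens inside __iter__ on the truncated list. This is also why PATCH 7 insists that fixed_num_batches be larger than skip_num_batches: the skip indexes into the already fixed-length dataloader. A minimal sketch of that behaviour, assuming the patched modalities package is importable (the sketch is illustrative and not part of the patch series):

from torch.utils.data import BatchSampler, SequentialSampler

from modalities.dataloader.samplers import ResumableBatchSampler

# The underlying sampler yields the batches [0, 1], [2, 3], [4, 5], [6, 7], [8, 9].
underlying = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=True)

# max_num_elements keeps only the first 4 batches; start_index then skips 1 of them.
sampler = ResumableBatchSampler(start_index=1, underlying_batch_sampler=underlying, max_num_elements=4)
print(list(sampler))  # expected: [[2, 3], [4, 5], [6, 7]]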
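
The arithmetic exercised by the new test can be checked without any framework code. The numbers below mirror the test YAML (128 global training tokens, sequence length 4, world size 2, micro batch size 2, a 1000-sample dataset) rather than any shipped configuration, and the helper is a hypothetical illustration, not part of the code base:

# Per-rank dataloader length as fixed by local_num_batches_from_num_tokens:
global_num_train_tokens, sequence_length, world_size, batch_size = 128, 4, 2, 2
fixed_num_batches = global_num_train_tokens // sequence_length // world_size  # 16

def expected_rank_content(rank):
    # Illustrative helper: without shuffling, the distributed sampler hands rank r
    # every world_size-th sample; the fixed number of batches then cuts off the rest.
    samples = list(range(rank, 1000, world_size))
    batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
    return batches[:fixed_num_batches]

assert expected_rank_content(0)[0] == [0, 2] and expected_rank_content(0)[-1] == [60, 62]
assert expected_rank_content(1)[0] == [1, 3] and expected_rank_content(1)[-1] == [61, 63]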