Merge pull request #180 from Modalities/dataloader_with_fixed_size
Dataloader with fixed size
fromm-m authored Jul 22, 2024
2 parents 15ed069 + e9e93ab commit 5fb8b69
Showing 6 changed files with 160 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/modalities/config/config.py
@@ -314,6 +314,7 @@ class LLMDataLoaderConfig(BaseModel):
pin_memory: bool
shuffle: bool
skip_num_batches: Optional[int] = 0
fixed_num_batches: Optional[int] = None


class RepeatingDataLoaderConfig(BaseModel):
22 changes: 21 additions & 1 deletion src/modalities/dataloader/dataloader_factory.py
@@ -5,6 +5,7 @@

from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
from modalities.dataloader.samplers import ResumableBatchSampler
from modalities.exceptions import ConfigError


class DataloaderFactory:
@@ -18,6 +19,7 @@ def get_dataloader(
pin_memory: bool,
shuffle: bool,
skip_num_batches: Optional[int] = 0,
fixed_num_batches: Optional[int] = None,
) -> LLMDataLoader:
"""Factory method for the instantiation of LLMDataLoader
@@ -34,11 +36,28 @@
skip_num_batches must not be confused with the number of optimizer steps!
skip_num_batches = num optimizer steps * gradient accumulation steps
Defaults to 0.
fixed_num_batches (Optional[int]): Fixed length of the dataloader, achieved by cutting off any subsequent batches.
    Note that this is NOT the global number of batches, but the number of batches that an
    individual rank sees. Make sure that the dataloader provides at least fixed_num_batches.
    Defaults to None.
Returns:
LLMDataLoader: Instance of LLMDataLoader
"""
batch_sampler = ResumableBatchSampler(start_index=skip_num_batches, underlying_batch_sampler=batch_sampler)

batch_sampler = ResumableBatchSampler(
start_index=skip_num_batches, underlying_batch_sampler=batch_sampler, max_num_elements=fixed_num_batches
)

if fixed_num_batches is not None and fixed_num_batches <= skip_num_batches:
    raise ConfigError("fixed_num_batches must be larger than skip_num_batches")

# make sure that the batch sampler has enough elements so that we can fix the number of batches to fixed_num_batches
if fixed_num_batches is not None and len(batch_sampler) < fixed_num_batches - skip_num_batches:
raise ConfigError(
f"The dataloader contains only {len(batch_sampler)} batches, which is less than "
f"specified fixed amount of batches of {fixed_num_batches}."
)

dataloader = LLMDataLoader(
dataloader_tag=dataloader_tag,
@@ -49,6 +68,7 @@
pin_memory=pin_memory,
shuffle=shuffle,
)

return dataloader

@staticmethod
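A minimal usage sketch of the new parameter (not part of this commit). The collapsed portion of the signature is assumed to mirror the keys of the data_loader config in the YAML further below (dataset, batch_sampler, collate_fn, num_workers), so treat the exact kwargs as illustrative rather than authoritative:

# Sketch only: the dataset/batch_sampler/collate_fn/num_workers kwargs are
# assumed from the YAML config below; they are not visible in this diff.
from torch.utils.data import BatchSampler, SequentialSampler

from modalities.dataloader.dataloader_factory import DataloaderFactory

dataset = list(range(100))  # stand-in dataset yielding 50 batches of size 2
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)

dataloader = DataloaderFactory.get_dataloader(
    dataloader_tag="train",
    dataset=dataset,
    batch_sampler=batch_sampler,
    collate_fn=lambda batch: batch,  # identity collation, as in the test below
    num_workers=0,
    pin_memory=False,
    shuffle=False,
    skip_num_batches=0,
    fixed_num_batches=10,  # cap this rank's dataloader at 10 batches
)
assert len(dataloader) == 10  # mirrors the length assertion in the test below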
12 changes: 11 additions & 1 deletion src/modalities/dataloader/samplers.py
@@ -1,21 +1,31 @@
from typing import Optional

from torch.utils.data import BatchSampler, Sampler


class ResumableBatchSampler(Sampler):
def __init__(self, start_index: int, underlying_batch_sampler: BatchSampler):
def __init__(
self, start_index: int, underlying_batch_sampler: BatchSampler, max_num_elements: Optional[int] = None
):
"""Sampler which starts at a specified batch index and continues sampling for
for a given sampler. Works with normal samplers and BatchSamplers.
Args:
start_index (int): index to start sampling from
existing_sampler (Sampler): Sampler from which we want to continue
max_num_elements (Optional[int]): The maximum number of elements the sampler returns. Default None.
"""

self.start_index = start_index
self.max_num_elements = max_num_elements
self.underlying_batch_sampler = underlying_batch_sampler
# NOTE: we are only iterating over the indices, not the actual data,
# so this is relatively cheap
self.indices = list(iter(self.underlying_batch_sampler))
# We discard the samples that come after max_num_elements.
# NOTE that skipping is implemented in __iter__ and __len__.
if self.max_num_elements is not None:
self.indices = self.indices[:max_num_elements]

def __iter__(self):
return iter(self.indices[self.start_index :])
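The interplay of start_index and max_num_elements follows directly from the indices handling above; a self-contained sketch using plain torch samplers:

from torch.utils.data import BatchSampler, SequentialSampler

from modalities.dataloader.samplers import ResumableBatchSampler

# The underlying sampler yields [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]].
underlying = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=True)

sampler = ResumableBatchSampler(
    start_index=1, underlying_batch_sampler=underlying, max_num_elements=4
)
# The indices are first capped to the first 4 batches, then start_index
# batches are skipped on iteration:
assert list(sampler) == [[2, 3], [4, 5], [6, 7]]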
3 changes: 1 addition & 2 deletions src/modalities/utils/number_conversion.py
@@ -54,8 +54,7 @@ def get_local_num_batches_from_num_samples(num_ranks: int, global_num_samples: i
def get_local_num_batches_from_num_tokens(num_ranks: int, global_num_tokens: int, sequence_length: int) -> int:
"""Calculates the number of local batches for each rank, given the global
number of tokens and number of ranks.
This helper function is primarily used to calculate the number of batches to
skip when resuming a dataloader during warmstart.
This helper function is primarily used to calculate a dataloader's number of batches (total and to skip).
Args:
num_ranks (int): _description_
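The function body is collapsed in this diff; judging from the signature above and the test below (which expects 128 // 4 // 2 = 16 batches per rank), the conversion presumably boils down to:

def get_local_num_batches_from_num_tokens(num_ranks: int, global_num_tokens: int, sequence_length: int) -> int:
    # tokens -> samples (each sample consumes sequence_length tokens),
    # then split the samples evenly across the ranks
    return global_num_tokens // sequence_length // num_ranks

# consistent with the test below: 128 // 4 // 2 == 16
assert get_local_num_batches_from_num_tokens(num_ranks=2, global_num_tokens=128, sequence_length=4) == 16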
71 changes: 70 additions & 1 deletion tests/dataloader/test_dataloader.py
@@ -1,8 +1,10 @@
import math
from collections.abc import Sequence
from pathlib import Path
from typing import Dict
from typing import Any, Dict, List

import numpy as np
import pytest
import torch
from pydantic import BaseModel
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler
@@ -11,11 +13,25 @@
from modalities.config.config import load_app_config_dict
from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType
from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
from modalities.dataloader.dataset import Dataset
from modalities.dataloader.samplers import ResumableBatchSampler
from modalities.models.gpt2.collator import CollateFnIF
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry


class SequenceDataset(Dataset):
def __init__(self, sequence: Sequence):
super().__init__(raw_data_path=None, sample_key=None)
self.sequence = sequence

def __len__(self) -> int:
return len(self.sequence)

def __getitem__(self, idx: int) -> Any:
return self.sequence[idx]


def test_resumable_dataloader():
batch_size = 3
start_index = 2
@@ -223,3 +239,56 @@ class DataloaderTestModel(BaseModel):
for batch_1, batch_2 in zip(batches_rank_0, batches_rank_1):
assert ~(batch_1.samples["input_ids"] == batch_2.samples["input_ids"]).all()
assert ~(batch_1.targets["target_ids"] == batch_2.targets["target_ids"]).all()


@pytest.mark.parametrize(
"global_rank",
[0, 1],
)
def test_dataloader_with_fixed_num_batches(global_rank):
class DataloaderTestModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
fixed_num_batches: int

class IdentityCollateFn(CollateFnIF):
def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
return batch

root_dir = Path(__file__).parents[0]

config_path = root_dir / "yaml_configs/dataloader_with_fixed_num_batches.yaml"
# we inject a prebuilt dataset and collate_fn, as well as the global rank constant, from outside
dataset = SequenceDataset(list(range(1000)))
config_dict = load_app_config_dict(config_path)
config_dict["settings"]["cuda_env"]["global_rank"] = global_rank
config_dict["train_dataloader"]["config"]["batch_sampler"]["config"]["sampler"]["config"]["rank"] = global_rank
config_dict["train_dataset"] = dataset
config_dict["collate_fn"] = IdentityCollateFn()

# build the remaining components
registry = Registry(COMPONENTS)
component_factory = ComponentFactory(registry=registry)
components: DataloaderTestModel = component_factory.build_components(
config_dict=config_dict, components_model_type=DataloaderTestModel
)
dataloader = components.train_dataloader

# calculate fixed_num_batches and compare it with the value computed
# during the component build, as well as with the dataloader length
cfg = config_dict["settings"]["training"]
world_size = config_dict["settings"]["cuda_env"]["world_size"]
calculated_fixed_num_batches = cfg["global_num_train_tokens"] // cfg["sequence_length"] // world_size
assert calculated_fixed_num_batches == components.fixed_num_batches
assert len(dataloader) == calculated_fixed_num_batches

# We make sure that the dataloader outputs the correct batches as follows:
# The dataset contains 1000 samples (NOTE that we neglected sequence_length and made each sample an integer value).
# We calculated 16 batches per rank above and have 2 ranks in total.
# Therefore, the dataloader for rank 0 returns 16 ordered batches of batch_size 2.
# The batches are ordered and not shuffled, as per the YAML configuration.
# We expect the following output:
# [[0, 2], [4, 6], [8, 10], ..., [56, 58], [60, 62]] (global_rank=0)
# [[1, 3], [5, 7], [9, 11], ..., [57, 59], [61, 63]] (global_rank=1)
calculated_dataloader_content = np.array(list(range(global_rank, 64 + global_rank, 2))).reshape(-1, 2).tolist()
actual_dataloader_content = [i for i in dataloader]
assert calculated_dataloader_content == actual_dataloader_content
56 changes: 56 additions & 0 deletions tests/dataloader/yaml_configs/dataloader_with_fixed_num_batches.yaml
@@ -0,0 +1,56 @@
settings:
referencing_keys:
sample_key: input_ids
target_key: target_ids
training:
local_train_micro_batch_size: 2
global_num_seen_tokens: 0
global_num_train_tokens: 128
sequence_length: 4
cuda_env:
global_rank: 0
world_size: 2

fixed_num_batches:
component_key: number_conversion
variant_key: local_num_batches_from_num_tokens
config:
num_ranks: ${settings.cuda_env.world_size}
global_num_tokens: ${settings.training.global_num_train_tokens}
sequence_length: ${settings.training.sequence_length}

train_dataloader:
component_key: data_loader
variant_key: default
config:
num_workers: 2
pin_memory: true
shuffle: false
dataloader_tag: train
skip_num_batches: 0
fixed_num_batches:
instance_key: fixed_num_batches
pass_type: BY_REFERENCE
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
batch_sampler:
component_key: batch_sampler
variant_key: default
config:
batch_size: ${settings.training.local_train_micro_batch_size}
drop_last: true
sampler:
component_key: sampler
variant_key: distributed_sampler
config:
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
drop_last: true
shuffle: false
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE
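
With these settings, fixed_num_batches resolves to global_num_train_tokens // sequence_length // num_ranks = 128 // 4 // 2 = 16 batches per rank, which is exactly the value that test_dataloader_with_fixed_num_batches asserts for both components.fixed_num_batches and len(dataloader).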
