Merge pull request #261 from Modalities/combined_dataset_feature

Combined dataset feature

le1nux authored Nov 22, 2024
2 parents e22ebb2 + 7aef91b commit 94cf3f0
Showing 40 changed files with 644 additions and 791 deletions.
25 changes: 24 additions & 1 deletion CHANGELOG_DEV.md
@@ -65,7 +65,7 @@ This [PR](https://github.com/Modalities/modalities/pull/236) removes all code re

**Breaking changes:**
* None
*


## PR #254 Warmstart infrastructure switch

@@ -85,3 +85,26 @@ This PR mainly addresses the warmstart of model training, e.g., after GPU crashe

**Breaking Changes**
* The settings part of the configs has been completely refactored.


## PR #261 Dataloader inefficiencies fix and combined dataset feature

This PR addresses issue #258 (inefficiencies in the dataloader) and additionally introduces a combined dataset, which lets a single dataset comprise a list of datasets and iterate over all of them.
As part of fixing the dataloader inefficiencies, the sample skipping functionality is no longer implemented on the dataloader level but in an adapted version of the PyTorch `DistributedSampler`. I reran a warm start and the learning curve is equivalent to that of a full, non-warmstarted run.

<img width="1415" alt="Screenshot 2024-09-27 at 10 36 19" src="https://github.com/user-attachments/assets/65dfb1ed-e96b-4f50-a127-bc9d240ddff9">
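
The combined dataset mentioned above boils down to index arithmetic over the underlying datasets. A minimal sketch of the idea in Python (illustrative only; the class name and attributes are assumptions, not the exact implementation added in this PR):

```python
import bisect
from itertools import accumulate

from torch.utils.data import Dataset


class CombinedDatasetSketch(Dataset):
    """Wraps a list of datasets and exposes them as one dataset."""

    def __init__(self, datasets: list[Dataset]):
        self.datasets = datasets
        # cumulative lengths, e.g. [100, 250, 300] for sub-dataset lengths 100, 150, 50
        self.cumulative_sizes = list(accumulate(len(d) for d in datasets))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx: int):
        # locate the sub-dataset whose cumulative range contains idx ...
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # ... and translate the global index into a local one
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        return self.datasets[dataset_idx][idx - offset]
```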


**General Changes**
* Introduced `ResumableDistributedSampler`, a copy of the PyTorch `DistributedSampler` extended with the ability to skip samples. It is now used for warmstarts, replacing the `skip_num_batches` mechanism in the dataloader. Previously, when skipping samples, the dataloader had to instantiate a `ResumableBatchSampler`, which internally iterated over all dataset indices. For small datasets this was fine, but for larger datasets (in the trillion-token range) it became a bottleneck at instantiation time:
https://github.com/Modalities/modalities/blob/b79d04d3e92d0845c5ec91f8dd41176fd543cb23/src/modalities/dataloader/samplers.py#L25-L28
Skipping in the `ResumableDistributedSampler` is now O(1), and the `ResumableBatchSampler` was removed from the codebase (see the first sketch after this list).
* Replaced the packed index generation routine, which was inefficient due to a Python for loop,
https://github.com/Modalities/modalities/blob/b79d04d3e92d0845c5ec91f8dd41176fd543cb23/src/modalities/dataloader/dataset.py#L331-L334
with a vectorized version (see the second sketch after this list).
* Added the new `NumberConversion` routine `num_samples_from_num_tokens`.
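
A minimal sketch of the O(1) skipping idea, assuming the sampler receives the number of already-consumed global samples (the `skip_num_global_samples` name follows the new config field; the actual implementation may differ in detail):

```python
from torch.utils.data import Dataset, DistributedSampler


class ResumableDistributedSamplerSketch(DistributedSampler):
    """DistributedSampler that can resume by dropping already-seen samples."""

    def __init__(self, dataset: Dataset, skip_num_global_samples: int = 0, **kwargs):
        super().__init__(dataset, **kwargs)
        # the globally skipped samples are split evenly across the ranks
        self.skip_num_local_samples = skip_num_global_samples // self.num_replicas

    def __len__(self) -> int:
        return self.num_samples - self.skip_num_local_samples

    def __iter__(self):
        indices = list(super().__iter__())  # this rank's indices for the current epoch
        # resume by slicing off the consumed prefix once, instead of iterating over
        # all indices inside a batch sampler as before
        return iter(indices[self.skip_num_local_samples:])
```

And a sketch of the vectorized packed-index generation, assuming fixed-size blocks of `block_size` tokens with `token_size_in_bytes` bytes per token (the exact block layout of the packed dataset may differ):

```python
import numpy as np


def generate_packing_index_vectorized(num_samples: int, block_size: int,
                                      token_size_in_bytes: int) -> list[tuple[int, int]]:
    """Builds the (byte_offset, byte_length) index for all samples at once."""
    bytes_per_block = block_size * token_size_in_bytes
    offsets = np.arange(num_samples, dtype=np.int64) * bytes_per_block
    lengths = np.full(num_samples, bytes_per_block, dtype=np.int64)
    # zip back into the list-of-tuples format the loop-based version produced
    return list(zip(offsets.tolist(), lengths.tolist()))
```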

**Breaking Changes**
* Removed `RepeatingDataLoader`, a feature for running multiple epochs that was never actively used and would have been complex to maintain when refactoring the sampling. If needed, we could reimplement it.
* In the settings, the `training_progress` section now has `num_seen_samples` instead of `local_num_seen_batches`, as skipping is done on the sampler level and not on the dataloader level anymore (a hedged migration sketch follows below).
* The `batch_size` and `fast_forward_batch_id` fields in the `LLMDataLoader` are not needed anymore and were removed.
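
For configs written against the old `local_num_seen_batches` field, the new global sample count can presumably be derived as follows (a hedged sketch; the helper name and the assumption that every rank consumed the same number of micro batches are mine, not part of this PR):

```python
def num_seen_samples_from_local_batches(local_num_seen_batches: int,
                                        local_micro_batch_size: int,
                                        world_size: int) -> int:
    # every rank is assumed to have consumed the same number of micro batches
    return local_num_seen_batches * local_micro_batch_size * world_size
```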
14 changes: 7 additions & 7 deletions config_files/training/config_example_coca.yaml
@@ -46,7 +46,7 @@ settings:
training_progress:
global_num_seen_tokens: 0
num_seen_steps: 0
local_num_seen_batches: 0
num_seen_samples: 0
last_step: -1
coca_example_settings:
train_num_samples: 64
@@ -96,7 +96,6 @@ train_dataloader:
num_workers: 2
pin_memory: true
dataloader_tag: train
skip_num_batches: ${settings.training_progress.local_num_seen_batches}
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
@@ -108,16 +107,17 @@ train_dataloader:
drop_last: true
sampler:
component_key: sampler
variant_key: distributed_sampler
variant_key: resumable_distributed_sampler
config:
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
shuffle: true
drop_last: true
seed: 42
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
drop_last: true
skip_num_global_samples: ${settings.training_progress.num_seen_samples}
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE
16 changes: 8 additions & 8 deletions config_files/training/config_lorem_ipsum.yaml
@@ -26,7 +26,7 @@ settings:
local_train_micro_batch_size: 1
sequence_length: 256
training_target:
num_target_tokens:
num_target_tokens:
component_key: number_conversion
variant_key: num_tokens_from_packed_mem_map_dataset_continuous
config:
@@ -47,7 +47,7 @@ settings:
training_progress:
global_num_seen_tokens: 0
num_seen_steps: 0
local_num_seen_batches: 0
num_seen_samples: 0
last_step: -1

collate_fn:
@@ -72,7 +72,6 @@ train_dataloader:
num_workers: 2
pin_memory: true
dataloader_tag: train
skip_num_batches: ${settings.training_progress.local_num_seen_batches}
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
@@ -84,16 +83,17 @@ train_dataloader:
drop_last: true
sampler:
component_key: sampler
variant_key: distributed_sampler
variant_key: resumable_distributed_sampler
config:
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
shuffle: true
drop_last: true
seed: 42
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
drop_last: true
skip_num_global_samples: ${settings.training_progress.num_seen_samples}
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE
3 changes: 2 additions & 1 deletion docs/components/components.md
@@ -56,6 +56,7 @@
| dataset | mem_map_dataset | [DatasetFactory.get_mem_map_dataset](../../src/modalities/dataloader/dataset_factory.py)| [MemMapDatasetConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | MemMap Dataset |
| dataset | packed_mem_map_dataset_continuous | [DatasetFactory.get_packed_mem_map_dataset_continuous](../../src/modalities/dataloader/dataset_factory.py)| [PackedMemMapDatasetContinuousConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Packed Memory Mapped Dataset Continuous |
| dataset | dummy_dataset | [DatasetFactory.get_dummy_dataset](../../src/modalities/dataloader/dataset_factory.py)| [DummyDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dummy dataset creating random samples of specified shape |
| dataset | combined | [DatasetFactory.get_combined_dataset](../../src/modalities/dataloader/dataset_factory.py)| [CombinedDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dataset implementation combining multiple datasets into one. |

## Data sampling

@@ -76,7 +77,6 @@
|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---------------|--------------------|----------------|---------------|---------------------|-------------|
| data_loader | default | [DataloaderFactory.get_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [LLMDataLoaderConfig](../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | LLM Data loader extending pytorch data loader functionality |
| data_loader | repeating_data_loader | [DataloaderFactory.get_repeating_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [RepeatingDataLoaderConfig](../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | Data loader that repeats the given dataloader for the specified number of epochs. |

## Checkpointing

@@ -118,6 +118,7 @@
|---------------|--------------------|----------------|---------------|---------------------|-------------|
| number_conversion | local_num_batches_from_num_samples | [NumberConversion.get_local_num_batches_from_num_samples](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of samples and number of ranks. |
| number_conversion | local_num_batches_from_num_tokens | [NumberConversion.get_local_num_batches_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of tokens and number of ranks. |
| number_conversion | num_samples_from_num_tokens | [NumberConversion.get_num_samples_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumSamplesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of global samples, given the global number of tokens and sequence length. |
| number_conversion | num_steps_from_num_samples | [NumberConversion.get_num_steps_from_num_samples](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of samples, local micro batch size and number of ranks. |
| number_conversion | num_steps_from_num_tokens | [NumberConversion.get_num_steps_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of tokens, local micro batch size and number of ranks. |
| number_conversion | num_tokens_from_num_steps | [NumberConversion.get_num_tokens_from_num_steps](../../src/modalities/utils/number_conversion.py)| [NumTokensFromNumStepsConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of tokens from the number of steps, number of ranks, local micro batch size, global number of tokens, sequence length and gradient accumulation steps. |
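
The table above reads more easily with the underlying arithmetic spelled out. A hedged sketch of what two of these converters roughly compute (simplified signatures; the authoritative versions live in `number_conversion.py`):

```python
def num_samples_from_num_tokens(global_num_tokens: int, sequence_length: int) -> int:
    # each packed sample is assumed to hold exactly `sequence_length` tokens
    return global_num_tokens // sequence_length


def num_steps_from_num_samples(global_num_samples: int, local_micro_batch_size: int,
                               world_size: int, gradient_accumulation_steps: int = 1) -> int:
    # one optimizer step consumes micro_batch_size * world_size * grad_accum samples
    return global_num_samples // (local_micro_batch_size * world_size * gradient_accumulation_steps)
```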
2 changes: 1 addition & 1 deletion src/modalities/__main__.py
@@ -148,7 +148,7 @@ def entry_point_data_create_raw_index(src_path: Path, index_path: Path):

index_path = LargeFileLinesReader.default_index_path(src_path, index_path)
if index_path.exists():
raise ValueError("index already exists. delete it or specify different output folder.")
raise ValueError(f"Index already exists in {index_path}. Delete it or specify different output folder.")

print(f"reading raw data from {src_path}")
print(f"writing index to {index_path}")
28 changes: 15 additions & 13 deletions src/modalities/config/config.py
@@ -265,6 +265,17 @@ class DistributedSamplerConfig(BaseModel):
drop_last: Literal[True] = True


class ResumableDistributedSamplerConfig(BaseModel):
dataset: PydanticDatasetIFType
rank: Annotated[int, Field(strict=True, ge=0)]
num_replicas: Annotated[int, Field(strict=True, ge=0)] = None
epoch: Annotated[int, Field(strict=True, ge=0)] = 0
shuffle: Optional[bool] = False
seed: Optional[int] = 0
drop_last: Literal[True] = True
skip_num_global_samples: Annotated[int, Field(strict=True, ge=0)] = 0


class MemMapDatasetConfig(BaseModel):
raw_data_path: FilePath
index_path: Optional[FilePath] = None
@@ -285,17 +296,16 @@ class PackedMemMapDatasetMegatronConfig(BaseModel):
sample_key: str


class CombinedDatasetConfig(BaseModel):
datasets: list[PydanticDatasetIFType]


class BatchSamplerConfig(BaseModel):
sampler: PydanticSamplerIFType
batch_size: Annotated[int, Field(strict=True, gt=0)]
drop_last: Literal[True] = True


class ResumableBatchSamplerConfig(BaseModel):
sampler: PydanticSamplerIFType
start_index: Annotated[int, Field(strict=True, gt=0)]


class GPT2LLMCollateFnConfig(BaseModel):
sample_key: str
target_key: str
@@ -308,14 +318,6 @@ class LLMDataLoaderConfig(BaseModel):
collate_fn: Optional[PydanticCollateFnIFType] = None
num_workers: Annotated[int, Field(strict=True, ge=0)]
pin_memory: bool
skip_num_batches: Optional[int] = 0
fixed_num_batches: Optional[int] = None


class RepeatingDataLoaderConfig(BaseModel):
dataloader: PydanticLLMDataLoaderIFType
reshuffle_after_epoch: Optional[bool] = False
num_epochs: Annotated[int, Field(strict=True, ge=1)]


class DummyProgressSubscriberConfig(BaseModel):
2 changes: 1 addition & 1 deletion src/modalities/config/instantiation_models.py
@@ -56,7 +56,7 @@ class TrainingTarget(BaseModel):
class TrainingProgress(BaseModel):
global_num_seen_tokens: Annotated[int, Field(strict=True, ge=0)]
num_seen_steps: Annotated[int, Field(strict=True, ge=0)]
local_num_seen_batches: Annotated[int, Field(strict=True, ge=0)]
num_seen_samples: Annotated[int, Field(strict=True, ge=0)]
last_step: Annotated[int, Field(strict=True, ge=-1)]


28 changes: 21 additions & 7 deletions src/modalities/dataloader/create_packed_data.py
@@ -323,12 +323,13 @@ class EmbeddedStreamData:
TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4
HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES

def __init__(self, data_path: Path):
def __init__(self, data_path: Path, load_index: Optional[bool] = True):
"""
Initializes an EmbeddedStreamData object.
Args:
data_path (Path): The path to the packed data file.
load_index (bool, optional): Whether to load the index. Defaults to True.
Raises:
FileNotFoundError: If the packed data file is not found at the specified path.
@@ -352,14 +353,27 @@ def __init__(self, data_path: Path):
self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False)

# get index
f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len)
pkl_encoded_index = f.read()
# contains the start offset and length of each segment
# as byte positions in the data section
self.index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index)
if load_index:
f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len)
pkl_encoded_index = f.read()
# contains the start offset and length of each segment
# as byte positions in the data section
self._index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index)
else:
self._index_base = None

# initialize memmapped data section
self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,))
self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,))

@property
def index_base(self) -> list[tuple[int, int]]:
if self._index_base is None:
raise ValueError("Index was not loaded. Set `load_index=True` during initialization.")
return self._index_base

@property
def data(self) -> np.ndarray:
return self._data


def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048):
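
A hypothetical usage sketch of the new `load_index` flag shown above (the file name and printed slice are made up for illustration): skipping the pickled index avoids reading a potentially very large structure when only the memory-mapped data section is needed.

```python
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData

# load only the memory-mapped data section; the pickled index stays on disk
stream = EmbeddedStreamData(Path("train_dataset.pbin"), load_index=False)
print(stream.data[:16])    # the data section is always available
# stream.index_base        # would raise ValueError, since load_index=False
```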