
Commit

Merge pull request #290 from Modalities/shuffle_tokenized_data
Shuffle Tokenized Data
mali-git authored Jan 21, 2025
2 parents 2673e1c + 3357f1a commit 43b6cc2
Showing 4 changed files with 226 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/modalities/__main__.py
@@ -22,6 +22,7 @@
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel, TrainingReportGenerator
from modalities.dataloader.shuffle_tokenized_data import shuffle_tokenized_data
from modalities.evaluator import Evaluator
from modalities.gym import Gym
from modalities.logging_broker.message_broker import MessageBroker
@@ -181,6 +182,36 @@ def CMD_entry_point_merge_packed_data(src_paths: list[Path], target_path: Path):
merge_packed_data_files(src_paths=src_paths, target_path=target_path)


@data.command(name="shuffle_tokenized_data")
@click.option(
"--input_data_path",
type=click_pathlib.Path(exists=False),
required=True,
help="Path to a tokenized file (.pbin).",
)
@click.option(
"--output_data_path",
type=click_pathlib.Path(exists=False),
required=True,
help="Path to write the shuffled tokenized data (.pbin).",
)
@click.option(
"--batch-size", type=int, default=100, show_default=True, help="Number of documents to process per batch."
)
def CMD_shuffle_tokenized_data(input_data_path: Path, output_data_path: Path, batch_size: int) -> None:
"""Entrypoint for shuffling tokenized data.
Args:
input_data_path (Path): The path to the input tokenized data (.pbin).
output_data_path (Path): Path to write the shuffled tokenized data (.pbin).
        batch_size (int): Number of documents to process per batch.
Returns:
None
"""
shuffle_tokenized_data(input_data_path=input_data_path, output_data_path=output_data_path, batch_size=batch_size)


class Main:
"""Main class that orchestrates the training process."""

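
The new subcommand plugs into the existing `data` command group; assuming the package's console entry point is named `modalities`, the shell invocation would look like `modalities data shuffle_tokenized_data --input_data_path ... --output_data_path ... --batch-size 100`. The same operation can also be scripted directly in Python; a minimal sketch with placeholder paths:

from pathlib import Path

from modalities.dataloader.shuffle_tokenized_data import shuffle_tokenized_data

# Placeholder paths; both refer to .pbin files produced by the packing pipeline.
shuffle_tokenized_data(
    input_data_path=Path("data/train.pbin"),
    output_data_path=Path("data/train_shuffled.pbin"),
    batch_size=100,
)
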
91 changes: 91 additions & 0 deletions src/modalities/dataloader/shuffle_tokenized_data.py
@@ -0,0 +1,91 @@
import pickle
import random
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData


def _process_batch(
batch: list[tuple[int, int]], data: bytes, start_position: int
) -> tuple[bytes, list[tuple[int, int]]]:
"""Process a batch of index entries to extract documents and create a new index.
Args:
batch (list[tuple[int, int]]): List of index entries [(start, length), ...].
data (bytes): Byte stream of the entire data loaded in memory.
start_position (int): The starting position for this batch in the byte stream.
Returns:
tuple[bytes, list[tuple[int, int]]]: A tuple containing the processed data (bytes)
and the new index [(position, length), ...].
"""
processed_data = []
new_index = []

current_position = start_position

for start, length in batch:
# Access the data slice directly from the in-memory bytes
document = data[start : start + length]
processed_data.append(document) # Already bytes

# Record the current position and length in the new index
new_index.append((current_position, length))
current_position += length

return b"".join(processed_data), new_index


def shuffle_tokenized_data(input_data_path: Path, output_data_path: Path, batch_size: int) -> None:
"""Shuffles a tokenized file (.pbin).
Shuffled data is written to the specified output file.
    Note that the tokenized data is fully materialized in memory.
Args:
input_data_path (Path): Path to the tokenized data (.pbin).
output_data_path (Path): Path to write the shuffled tokenized data.
batch_size (int): Number of documents to process per batch.
Returns:
None
"""
# Step 1: Load the entire data into memory
with input_data_path.open("rb") as f:
# Read the header
data_section_length_in_bytes = f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES)
data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little")

token_size_as_bytes = f.read(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES)

# Load the data
data = f.read(data_len)

# Load the index
pkl_encoded_index = f.read()
index_base = pickle.loads(pkl_encoded_index)

# Step 2: Shuffle the index
random.shuffle(index_base)

# Step 3: Divide the shuffled index into batches
batches: list[list[tuple[int, int]]] = [index_base[i : i + batch_size] for i in range(0, len(index_base), batch_size)]

header_data = data_section_length_in_bytes + token_size_as_bytes

with output_data_path.open("wb") as f:
# Write the header data
f.write(header_data)
current_position = 0
final_index = []

# Process and write each batch sequentially
for batch in batches:
data_segment, new_index = _process_batch(batch, data, current_position)
f.write(data_segment)
final_index.extend(new_index)
current_position += len(data_segment)

# Write the final index to the file
f.write(pickle.dumps(final_index))
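
For orientation, the writer above preserves the EmbeddedStreamData layout: a fixed-size header (data-section length followed by the token-size descriptor), the raw data section, and a pickled index of (position, length) tuples relative to the start of the data section. A minimal reader sketch that inverts this, mirroring the reads performed above (`read_documents` is an illustrative helper, not part of the commit):

import pickle
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData


def read_documents(pbin_path: Path) -> list[bytes]:
    with pbin_path.open("rb") as f:
        # Header: data-section length, then the token-size descriptor.
        data_len = int.from_bytes(
            f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES), byteorder="little"
        )
        _token_size_descriptor = f.read(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES)
        # Data section followed by the pickled index.
        data = f.read(data_len)
        index = pickle.loads(f.read())
    # Each index entry is (position, length) into the data section.
    return [data[start : start + length] for start, length in index]
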
14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -1,6 +1,7 @@
import dataclasses
import os
import pickle
import string
from pathlib import Path
from unittest.mock import MagicMock

@@ -12,6 +13,7 @@
from modalities.checkpointing.checkpoint_saving import CheckpointSaving
from modalities.config.config import load_app_config_dict
from modalities.dataloader.create_index import IndexGenerator
from modalities.dataloader.create_packed_data import PackedDataGenerator
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
from modalities.evaluator import Evaluator
@@ -223,3 +225,15 @@ def torch_distributed_cleanup():
else:
# see https://pytorch.org/docs/2.4/_modules/torch/cuda.html#device_count
torch.cuda._cached_device_count = None


@pytest.fixture
def encoding_set_up():
# Define the vocabulary
vocabulary = {char: idx for idx, char in enumerate(string.ascii_lowercase)}

# Ensure num_bytes_per_token is valid
num_bytes_per_token = PackedDataGenerator._get_required_num_of_bytes_to_repr(len(vocabulary))
    assert num_bytes_per_token == 1  # 26 lowercase-letter tokens fit into a single byte

return vocabulary, num_bytes_per_token
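
The fixture provides a toy character-level setup: the 26 letters of string.ascii_lowercase are mapped to token ids 0–25, so one byte per token suffices. A hypothetical usage sketch (the test name is illustrative only):

def test_encoding_set_up(encoding_set_up):
    vocabulary, num_bytes_per_token = encoding_set_up
    assert vocabulary["a"] == 0 and vocabulary["z"] == 25
    assert len(vocabulary) == 26
    assert num_bytes_per_token == 1  # 26 token ids fit into a single byte
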
90 changes: 90 additions & 0 deletions tests/dataloader/test_shuffle_tokenized_data.py
@@ -0,0 +1,90 @@
import pickle

from modalities.dataloader.create_packed_data import EmbeddedStreamData
from modalities.dataloader.shuffle_tokenized_data import _process_batch, shuffle_tokenized_data


def _tokenize(text: str, vocabulary: dict[str, int]) -> list[int]:
text = text.lower()
return [vocabulary[char] for char in text]


def _convert_tokens_to_bytes(tokens: list[int], num_bytes_per_token: int) -> bytes:
return b"".join([token.to_bytes(num_bytes_per_token, byteorder="little", signed=False) for token in tokens])


def test_process_batch(tmp_path, encoding_set_up):
vocabulary, num_bytes_per_token = encoding_set_up
# Create a temporary file
file_path = tmp_path / "test_data.pbin"
data = _tokenize(text="IloveModalities", vocabulary=vocabulary)
data = _convert_tokens_to_bytes(data, num_bytes_per_token=num_bytes_per_token)

with open(file_path, "wb") as f:
f.write(data)

# Load the data into memory
with open(file_path, "rb") as f:
in_memory_data = f.read()

# Define a batch
batch = [(0, 1), (1, 4), (5, 10)]

# Call the function
new_data, new_index = _process_batch(batch=batch, data=in_memory_data, start_position=0)

# Validate the result
expected_data = data
expected_index = batch
assert (new_data, new_index) == (expected_data, expected_index)


def test_shuffle_tokenized_data(tmp_path, encoding_set_up):
vocabulary, num_bytes_per_token = encoding_set_up
# Create test input data
data = _tokenize(text="IloveModalities", vocabulary=vocabulary)
data = _convert_tokens_to_bytes(data, num_bytes_per_token=num_bytes_per_token)
data_section_length_as_bytes = len(data).to_bytes(
EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
)
token_size_as_bytes = num_bytes_per_token.to_bytes(
EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little"
)
index = [(0, 1), (1, 4), (5, 10)]

# Prepare the input file
input_path = tmp_path / "input.pbin"
output_path = tmp_path / "output.pbin"
with input_path.open("wb") as f:
f.write(data_section_length_as_bytes)
f.write(token_size_as_bytes)
f.write(data)
f.write(pickle.dumps(index))

for batch_size in [1, 2, 3]:
# Call shuffle_tokenized_data
output_path = tmp_path / "input_shuffled.pbin"
shuffle_tokenized_data(input_data_path=input_path, output_data_path=output_path, batch_size=batch_size)

# Validate the output
assert output_path.is_file()

with output_path.open("rb") as f:
# Validate header and data
data_section_length_as_bytes_written = f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES)
assert data_section_length_as_bytes_written == data_section_length_as_bytes
assert f.read(len(token_size_as_bytes)) == token_size_as_bytes
data_len = int.from_bytes(data_section_length_as_bytes, byteorder="little")
data_written = f.read(data_len)

# Validate the shuffled index
written_index = pickle.loads(f.read())

# Extract substrings from the data using written_index
extracted_substrings = [data_written[start : start + length] for start, length in written_index]

            # Verify that these substrings match the originally defined ones
original_substrings = [data[start : start + length] for start, length in index]

# Ensure that extracted substrings are a valid permutation of original substrings
assert sorted(extracted_substrings) == sorted(original_substrings)
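
To make the test data concrete: with one byte per token, the lowercased text "ilovemodalities" yields 15 token bytes, and the index [(0, 1), (1, 4), (5, 10)] partitions them into the three documents "i", "love", and "modalities". Shuffling permutes whole documents but keeps each document's bytes contiguous, which is exactly what the permutation check above asserts. A standalone sketch (not part of the commit):

import string

vocabulary = {char: idx for idx, char in enumerate(string.ascii_lowercase)}
tokens = [vocabulary[c] for c in "IloveModalities".lower()]  # e.g. 'i' -> 8, 'l' -> 11
data = bytes(tokens)                                         # one byte per token
index = [(0, 1), (1, 4), (5, 10)]

documents = [data[start : start + length] for start, length in index]
assert documents == [
    bytes(vocabulary[c] for c in "i"),
    bytes(vocabulary[c] for c in "love"),
    bytes(vocabulary[c] for c in "modalities"),
]
# shuffle_tokenized_data reorders these documents and rewrites the index accordingly.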
