Merge pull request #290 from Modalities/shuffle_tokenized_data
Shuffle Tokenized Data
Showing 4 changed files with 226 additions and 0 deletions.
New file (91 additions): modalities/dataloader/shuffle_tokenized_data.py

import pickle
import random
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData


def _process_batch(
    batch: list[tuple[int, int]], data: bytes, start_position: int
) -> tuple[bytes, list[tuple[int, int]]]:
    """Process a batch of index entries to extract documents and create a new index.

    Args:
        batch (list[tuple[int, int]]): List of index entries [(start, length), ...].
        data (bytes): Byte stream of the entire data loaded in memory.
        start_position (int): The starting position for this batch in the byte stream.

    Returns:
        tuple[bytes, list[tuple[int, int]]]: A tuple containing the processed data (bytes)
            and the new index [(position, length), ...].
    """
    processed_data = []
    new_index = []

    current_position = start_position

    for start, length in batch:
        # Access the data slice directly from the in-memory bytes
        document = data[start : start + length]
        processed_data.append(document)  # Already bytes

        # Record the current position and length in the new index
        new_index.append((current_position, length))
        current_position += length

    return b"".join(processed_data), new_index


def shuffle_tokenized_data(input_data_path: Path, output_data_path: Path, batch_size: int) -> None:
    """Shuffles a tokenized file (.pbin).

    Shuffled data is written to the specified output file.
    Note that the tokenized data is fully materialized in memory.

    Args:
        input_data_path (Path): Path to the tokenized data (.pbin).
        output_data_path (Path): Path to write the shuffled tokenized data.
        batch_size (int): Number of documents to process per batch.

    Returns:
        None
    """
    # Step 1: Load the entire data into memory
    with input_data_path.open("rb") as f:
        # Read the header
        data_section_length_in_bytes = f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES)
        data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little")

        token_size_as_bytes = f.read(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES)

        # Load the data
        data = f.read(data_len)

        # Load the index
        pkl_encoded_index = f.read()
        index_base = pickle.loads(pkl_encoded_index)

    # Step 2: Shuffle the index
    random.shuffle(index_base)

    # Step 3: Divide the shuffled index into batches
    batches: list[list[tuple[int, int]]] = [
        index_base[i : i + batch_size] for i in range(0, len(index_base), batch_size)
    ]

    header_data = data_section_length_in_bytes + token_size_as_bytes

    with output_data_path.open("wb") as f:
        # Write the header data
        f.write(header_data)
        current_position = 0
        final_index = []

        # Process and write each batch sequentially
        for batch in batches:
            data_segment, new_index = _process_batch(batch, data, current_position)
            f.write(data_segment)
            final_index.extend(new_index)
            current_position += len(data_segment)

        # Write the final index to the file
        f.write(pickle.dumps(final_index))
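
For orientation, a minimal usage sketch of the new utility. The function name and keyword arguments come from the diff above; the file paths and batch size are illustrative, not taken from this PR:

    from pathlib import Path

    from modalities.dataloader.shuffle_tokenized_data import shuffle_tokenized_data

    # Hypothetical paths; substitute your own tokenized .pbin files.
    shuffle_tokenized_data(
        input_data_path=Path("data/train_tokenized.pbin"),
        output_data_path=Path("data/train_tokenized_shuffled.pbin"),
        batch_size=1024,
    )

Only the document order is shuffled; the header and the byte content of each document are preserved, so the output remains a valid .pbin file.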
New test file (90 additions) for shuffle_tokenized_data:

import pickle

from modalities.dataloader.create_packed_data import EmbeddedStreamData
from modalities.dataloader.shuffle_tokenized_data import _process_batch, shuffle_tokenized_data


def _tokenize(text: str, vocabulary: dict[str, int]) -> list[int]:
    text = text.lower()
    return [vocabulary[char] for char in text]


def _convert_tokens_to_bytes(tokens: list[int], num_bytes_per_token: int) -> bytes:
    return b"".join([token.to_bytes(num_bytes_per_token, byteorder="little", signed=False) for token in tokens])


def test_process_batch(tmp_path, encoding_set_up):
    vocabulary, num_bytes_per_token = encoding_set_up
    # Create a temporary file
    file_path = tmp_path / "test_data.pbin"
    data = _tokenize(text="IloveModalities", vocabulary=vocabulary)
    data = _convert_tokens_to_bytes(data, num_bytes_per_token=num_bytes_per_token)

    with open(file_path, "wb") as f:
        f.write(data)

    # Load the data into memory
    with open(file_path, "rb") as f:
        in_memory_data = f.read()

    # Define a batch
    batch = [(0, 1), (1, 4), (5, 10)]

    # Call the function
    new_data, new_index = _process_batch(batch=batch, data=in_memory_data, start_position=0)

    # Validate the result
    expected_data = data
    expected_index = batch
    assert (new_data, new_index) == (expected_data, expected_index)


def test_shuffle_tokenized_data(tmp_path, encoding_set_up):
    vocabulary, num_bytes_per_token = encoding_set_up
    # Create test input data
    data = _tokenize(text="IloveModalities", vocabulary=vocabulary)
    data = _convert_tokens_to_bytes(data, num_bytes_per_token=num_bytes_per_token)
    data_section_length_as_bytes = len(data).to_bytes(
        EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
    )
    token_size_as_bytes = num_bytes_per_token.to_bytes(
        EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little"
    )
    index = [(0, 1), (1, 4), (5, 10)]

    # Prepare the input file
    input_path = tmp_path / "input.pbin"
    output_path = tmp_path / "output.pbin"
    with input_path.open("wb") as f:
        f.write(data_section_length_as_bytes)
        f.write(token_size_as_bytes)
        f.write(data)
        f.write(pickle.dumps(index))

    for batch_size in [1, 2, 3]:
        # Call shuffle_tokenized_data
        output_path = tmp_path / "input_shuffled.pbin"
        shuffle_tokenized_data(input_data_path=input_path, output_data_path=output_path, batch_size=batch_size)

        # Validate the output
        assert output_path.is_file()

        with output_path.open("rb") as f:
            # Validate header and data
            data_section_length_as_bytes_written = f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES)
            assert data_section_length_as_bytes_written == data_section_length_as_bytes
            assert f.read(len(token_size_as_bytes)) == token_size_as_bytes
            data_len = int.from_bytes(data_section_length_as_bytes, byteorder="little")
            data_written = f.read(data_len)

            # Validate the shuffled index
            written_index = pickle.loads(f.read())

            # Extract substrings from the data using written_index
            extracted_substrings = [data_written[start : start + length] for start, length in written_index]

            # Verify that these substrings match the original defined ones
            original_substrings = [data[start : start + length] for start, length in index]

            # Ensure that extracted substrings are a valid permutation of original substrings
            assert sorted(extracted_substrings) == sorted(original_substrings)
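
Both tests rely on an encoding_set_up fixture that is not part of this diff (presumably provided by a shared conftest). A minimal stand-in, assuming a plain lowercase character vocabulary and one byte per token so that the hard-coded index offsets line up with the 15-character test string, could look like this:

    import pytest


    @pytest.fixture
    def encoding_set_up() -> tuple[dict[str, int], int]:
        # Hypothetical stand-in for the project's fixture: each lowercase letter
        # maps to a small integer id, and every token is encoded in a single byte.
        vocabulary = {char: idx for idx, char in enumerate("abcdefghijklmnopqrstuvwxyz")}
        num_bytes_per_token = 1
        return vocabulary, num_bytes_per_token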