From 5890220084ed03c92bde3ff7858ffad370aefe4b Mon Sep 17 00:00:00 2001
From: mali-git
Date: Tue, 21 Jan 2025 11:46:02 +0100
Subject: [PATCH] test: save tokenized input

---
 tests/conftest.py                             | 14 +++++++++
 .../dataloader/test_shuffle_tokenized_data.py | 31 +++++++++++++------
 2 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index c05cb2c8..694147b8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
 import dataclasses
 import os
 import pickle
+import string
 from pathlib import Path
 from unittest.mock import MagicMock
 
@@ -12,6 +13,7 @@
 from modalities.checkpointing.checkpoint_saving import CheckpointSaving
 from modalities.config.config import load_app_config_dict
 from modalities.dataloader.create_index import IndexGenerator
+from modalities.dataloader.create_packed_data import PackedDataGenerator
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
 from modalities.evaluator import Evaluator
@@ -223,3 +225,15 @@ def torch_distributed_cleanup():
     else:
         # see https://pytorch.org/docs/2.4/_modules/torch/cuda.html#device_count
         torch.cuda._cached_device_count = None
+
+
+@pytest.fixture
+def encoding_set_up():
+    # Define the vocabulary
+    vocabulary = {char: idx for idx, char in enumerate(string.ascii_lowercase)}
+
+    # Ensure num_bytes_per_token is valid
+    num_bytes_per_token = PackedDataGenerator._get_required_num_of_bytes_to_repr(len(vocabulary))
+    assert num_bytes_per_token == 1  # a 26-character vocabulary fits into a single byte per token
+
+    return vocabulary, num_bytes_per_token
diff --git a/tests/dataloader/test_shuffle_tokenized_data.py b/tests/dataloader/test_shuffle_tokenized_data.py
index 0da79d90..c7330360 100644
--- a/tests/dataloader/test_shuffle_tokenized_data.py
+++ b/tests/dataloader/test_shuffle_tokenized_data.py
@@ -4,10 +4,21 @@
 from modalities.dataloader.shuffle_tokenized_data import _process_batch, shuffle_tokenized_data
 
 
-def test_process_batch_with_embedded_stream_with_memmap(tmp_path):
+def _tokenize(text: str, vocabulary: dict[str, int]) -> list[int]:
+    text = text.lower()
+    return [vocabulary[char] for char in text]
+
+
+def _convert_tokens_to_bytes(tokens: list[int], num_bytes_per_token: int) -> bytes:
+    return b"".join([token.to_bytes(num_bytes_per_token, byteorder="little", signed=False) for token in tokens])
+
+
+def test_process_batch(tmp_path, encoding_set_up):
+    vocabulary, num_bytes_per_token = encoding_set_up
     # Create a temporary file
     file_path = tmp_path / "test_data.pbin"
-    data = b"IloveModalities"  # Example data
+    data = _tokenize(text="IloveModalities", vocabulary=vocabulary)
+    data = _convert_tokens_to_bytes(data, num_bytes_per_token=num_bytes_per_token)
     with open(file_path, "wb") as f:
         f.write(data)
 
@@ -23,25 +34,27 @@
     new_data, new_index = _process_batch(batch=batch, data=in_memory_data, start_position=0)
 
     # Validate the result
-    expected_data = b"IloveModalities"
-    expected_index = [(0, 1), (1, 4), (5, 10)]
+    expected_data = data
+    expected_index = batch
     assert (new_data, new_index) == (expected_data, expected_index)
 
 
-def test_shuffle_tokenized_data(tmp_path):
+def test_shuffle_tokenized_data(tmp_path, encoding_set_up):
+    vocabulary, num_bytes_per_token = encoding_set_up
     # Create test input data
-    data = b"IloveModalities"
+    data = _tokenize(text="IloveModalities", vocabulary=vocabulary)
+    data = _convert_tokens_to_bytes(data, num_bytes_per_token=num_bytes_per_token)
     data_section_length_as_bytes = len(data).to_bytes(
         EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
     )
-    token_size_in_bytes = 4
-    token_size_as_bytes = token_size_in_bytes.to_bytes(
+    token_size_as_bytes = num_bytes_per_token.to_bytes(
         EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little"
     )
     index = [(0, 1), (1, 4), (5, 10)]
 
     # Prepare the input file
     input_path = tmp_path / "input.pbin"
+    output_path = tmp_path / "output.pbin"
     with input_path.open("wb") as f:
         f.write(data_section_length_as_bytes)
         f.write(token_size_as_bytes)
@@ -51,7 +64,7 @@
     for batch_size in [1, 2, 3]:
         # Call shuffle_tokenized_data
         output_path = tmp_path / "input_shuffled.pbin"
-        shuffle_tokenized_data(input_path, batch_size=batch_size)
+        shuffle_tokenized_data(input_data_path=input_path, output_data_path=output_path, batch_size=batch_size)
 
         # Validate the output
         assert output_path.is_file()