feat: integrated global index creation
le1nux committed Jan 13, 2025
1 parent 004daeb commit 8be1ad5
Showing 4 changed files with 172 additions and 24 deletions.
28 changes: 23 additions & 5 deletions src/modalities/__main__.py
@@ -9,12 +9,13 @@

import click
import click_pathlib
from modalities.utils.logging import get_logger
from pydantic import BaseModel, FilePath

from modalities.api import (
convert_pytorch_to_hf_checkpoint,
create_raw_data_index,
create_global_index,
create_local_index,
create_shuffled_global_index,
generate_text,
merge_packed_data_files,
pack_encoded_data,
@@ -35,6 +36,7 @@
from modalities.running_env.cuda_env import CudaEnv
from modalities.trainer import Trainer
from modalities.util import get_total_number_of_trainable_parameters, print_rank_0
from modalities.utils.logging import get_logger


@click.group()
@@ -124,15 +126,15 @@ def data():
pass


@data.command(name="create_raw_index")
@data.command(name="create_local_index")
@click.argument("src_path", type=Path)
@click.option(
"--index_path",
type=Path,
default=None,
help="output path for index. will use parent directory of src_path if none.",
)
def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path):
def CMD_entry_point_data_create_local_index(src_path: Path, index_path: Path):
"""Utility CMD for indexing the content of a large jsonl-file.
Background is the ability to further process the respective file without loading it,
while splitting its content line-based. This step is necessary in advance of further processing like tokenization.
@@ -145,7 +147,23 @@ def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path):
Raises:
ValueError: If the index file already exists.
"""
create_raw_data_index(src_path=src_path, index_path=index_path)
create_local_index(src_path=src_path, index_path=index_path)


@data.command(name="create_global_index")
@click.option("--file_list_path", type=Path, required=True)
@click.option("--root_index_path", type=Path, required=True)
@click.option("--global_index_root_path", type=Path, required=True)
def CMD_entry_point_create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path):
create_global_index(
file_list_path=file_list_path, root_index_path=root_index_path, global_index_root_path=global_index_root_path
)


@data.command(name="create_shuffled_global_index")
@click.option("--global_index_file_path", type=Path, required=True)
def CMD_entry_point_create_shuffled_global_index(global_index_file_path: Path):
create_shuffled_global_index(global_index_file_path=global_index_file_path)
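

# For reference, the renamed and new commands can be invoked along these lines
# (a sketch assuming the package is run via `python -m modalities`; all paths
# are illustrative):
#
#   python -m modalities data create_local_index data/part_00.jsonl --index_path data/part_00.idx
#   python -m modalities data create_global_index --file_list_path data/file_list.txt \
#       --root_index_path data/indices --global_index_root_path data/global
#   python -m modalities data create_shuffled_global_index \
#       --global_index_file_path data/global/global_inorder.idx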


@data.command(name="pack_encoded_data")
12 changes: 11 additions & 1 deletion src/modalities/api.py
@@ -33,7 +33,7 @@ class FileExistencePolicy(Enum):
OVERRIDE = "override"


def create_raw_data_index(
def create_local_index(
src_path: Path, index_path: Path, file_existence_policy: FileExistencePolicy = FileExistencePolicy.ERROR
):
"""Creates the index file for the content of a large jsonl-file. The index file
@@ -71,6 +71,16 @@ def create_raw_data_index(
generator.create_index(index_path)


def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path:
    # Delegate to the new global-index module (aliased import assumed here; the
    # import hunk is not shown in this diff). Calling the unaliased name from
    # inside this wrapper would recurse into the wrapper itself.
    global_index_file_path = _create_global_index(file_list_path, root_index_path, global_index_root_path)
    return global_index_file_path


def create_shuffled_global_index(global_index_file_path: Path) -> Path:
    # Same aliasing rationale as above.
    global_shuffled_index_file_path = _create_shuffled_global_index(global_index_file_path)
    return global_shuffled_index_file_path
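

# Taken together, the two steps chain naturally: the in-order index produced by
# create_global_index feeds directly into create_shuffled_global_index. A minimal
# usage sketch (paths are illustrative):
#
#   from pathlib import Path
#   from modalities.api import create_global_index, create_shuffled_global_index
#
#   global_index_path = create_global_index(
#       file_list_path=Path("data/file_list.txt"),
#       root_index_path=Path("data/indices"),
#       global_index_root_path=Path("data/global"),
#   )
#   shuffled_index_path = create_shuffled_global_index(global_index_file_path=global_index_path)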


def generate_text(config_file_path: FilePath):
"""Inference function to generate text with a given model.
112 changes: 112 additions & 0 deletions (new file)
@@ -0,0 +1,112 @@
import pickle
from pathlib import Path

import numpy as np
import tqdm


def _get_global_index_file_path(global_index_root_path: Path) -> Path:
global_index_file_path = global_index_root_path / f"{global_index_root_path.name}_inorder.idx"
return global_index_file_path


def _get_file_list(file_list_path: Path) -> list[Path]:
file_list: list[Path] = []
with open(file_list_path, "r") as f:
for line in f:
file_list.append(Path(line.strip()))
return file_list
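

# The file list is plain text with one data-file path per line; entries are
# interpreted relative to root_index_path downstream. An illustrative example:
#
#   part_00.jsonl
#   subdir/part_01.jsonl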


def _get_file_id_file_path_mappings(file_list: list[Path]) -> tuple[dict[Path, int], dict[int, Path]]:
file_path_to_id = {file_path.with_suffix(""): i for i, file_path in enumerate(file_list)}
id_to_file_path = {i: file_path.with_suffix("") for i, file_path in enumerate(file_list)}
return file_path_to_id, id_to_file_path


def _get_local_index_paths(file_list: list[Path], root_index_path: Path, global_index_root_path: Path) -> list[Path]:
local_index_paths = [
path.with_suffix(".idx")
for path in file_list
if (root_index_path / path).is_relative_to(global_index_root_path)
]
return local_index_paths


def _get_total_num_documents(local_index_paths: list[Path], root_index_path: Path) -> int:
num_documents = 0
for local_index_path in tqdm.tqdm(local_index_paths, desc="Counting total number of documents"):
with open(root_index_path / local_index_path, "rb") as f:
index = pickle.load(f)
num_documents += len(index)
return num_documents


def _populate_global_index_array(
global_index_file_path: Path,
num_documents: int,
local_index_paths: list[Path],
root_index_path: Path,
file_path_to_id: dict[Path, int],
) -> np.memmap:
shape = (num_documents + 1, 3)
global_index_array = np.memmap(global_index_file_path, dtype="int64", mode="w+", shape=shape)

# the first row is reserved for the shape of the array and whether rows are shuffled.
# <num rows, num columns, is_shuffled>
global_index_array[0] = np.array([*shape, 0])
start_index = 1
for local_index_path in tqdm.tqdm(local_index_paths, desc="Populating global index array"):
with open(root_index_path / local_index_path, "rb") as f:
local_index = pickle.load(f)

local_index_array = np.array(local_index)
# add the file id to the local index
file_id = file_path_to_id[local_index_path.with_suffix("")]
local_index_array = np.insert(local_index_array, 0, file_id, axis=1)

global_index_array[start_index : start_index + len(local_index_array)] = local_index_array
start_index += len(local_index_array)
global_index_array.flush()
return global_index_array
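

# For orientation, the file written above can be read back as follows. A sketch,
# assuming each pickled local index stores one (offset, length) pair per document,
# so that global rows become (file_id, offset, length) after the column insert:
#
#   flat = np.memmap(global_index_file_path, dtype="int64", mode="r")
#   num_rows, num_cols, is_shuffled = flat[:3]    # header row
#   table = flat.reshape(num_rows, num_cols)
#   file_id, offset, length = table[1]            # first document row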


def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path:
global_index_file_path = _get_global_index_file_path(global_index_root_path)

file_list = _get_file_list(file_list_path)

file_path_to_id, _ = _get_file_id_file_path_mappings(file_list)
local_index_paths = _get_local_index_paths(file_list, root_index_path, global_index_root_path)
num_documents = _get_total_num_documents(local_index_paths, root_index_path)

_populate_global_index_array(
global_index_file_path, num_documents, local_index_paths, root_index_path, file_path_to_id
)
return global_index_file_path


def create_shuffled_global_index(global_index_file_path: Path) -> Path:
global_shuffled_index_file_path = (
global_index_file_path.parent / f"{global_index_file_path.stem.replace('inorder', 'shuffle_index')}.idx"
)
    print(f"Shuffled index will be written to {global_shuffled_index_file_path}")

    # read the header row (num_rows, num_cols, is_shuffled) from the flat int64 memmap
    num_rows, _, _ = np.memmap(global_index_file_path, dtype="int64", mode="r")[0:3]

print(f"Shuffling {num_rows-1} global index indices")
# we count from 1 since the 0th row contains meta information (num_rows, num_cols, is_shuffled)
indices = np.arange(1, num_rows)
np.random.shuffle(indices)

print(f"Writing out shuffled global index array with {num_rows} elements")
global_shuffled_index_array = np.memmap(
global_shuffled_index_file_path, dtype="int64", mode="w+", shape=(len(indices),)
)
chunk_size = 10
for i in tqdm.tqdm(range(0, len(indices), chunk_size)):
chunk_indices = indices[i : i + chunk_size]
global_shuffled_index_array[i : i + len(chunk_indices)] = chunk_indices
global_shuffled_index_array.flush()
return global_shuffled_index_file_path
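
Downstream, the two index files are meant to be read together: the shuffled file is a flat int64 permutation over the document rows of the in-order index. A minimal consumer sketch, under the same layout assumptions as above (the (file_id, offset, length) row format is assumed; file names follow the helpers in this module, paths are illustrative):

    import numpy as np

    # in-order index: header row followed by one row per document
    inorder = np.memmap("data/global/global_inorder.idx", dtype="int64", mode="r")
    num_rows, num_cols = int(inorder[0]), int(inorder[1])
    table = inorder.reshape(num_rows, num_cols)

    # shuffled permutation of the row numbers 1 .. num_rows - 1
    perm = np.memmap("data/global/global_shuffle_index.idx", dtype="int64", mode="r")

    file_id, offset, length = table[perm[0]]  # first document in shuffled order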
44 changes: 26 additions & 18 deletions tests/utils/test_number_conversion.py
@@ -4,7 +4,7 @@
import pytest

from modalities.dataloader.dataset_factory import DatasetFactory
from modalities.utils.number_conversion import NumberConversion
from modalities.utils.number_conversion import TrainingNumberConversion


@pytest.mark.parametrize(
@@ -15,7 +15,9 @@ def test_get_local_num_batches_from_num_samples(
num_ranks: int, global_num_samples: int, local_micro_batch_size: int, expected: int
):
assert (
NumberConversion.get_local_num_batches_from_num_samples(num_ranks, global_num_samples, local_micro_batch_size)
TrainingNumberConversion.get_local_num_batches_from_num_samples(
num_ranks, global_num_samples, local_micro_batch_size
)
== expected
)

@@ -28,7 +30,7 @@ def test_get_local_num_batches_from_num_tokens(
num_ranks: int, global_num_tokens: int, sequence_length: int, local_micro_batch_size: int, expected: int
):
assert (
NumberConversion.get_local_num_batches_from_num_tokens(
TrainingNumberConversion.get_local_num_batches_from_num_tokens(
num_ranks, global_num_tokens, sequence_length, local_micro_batch_size
)
== expected
@@ -47,7 +49,7 @@ def test_get_num_steps_from_num_samples(
expected: int,
):
assert (
NumberConversion.get_num_steps_from_num_samples(
TrainingNumberConversion.get_num_steps_from_num_samples(
num_ranks, local_micro_batch_size, global_num_samples, gradient_accumulation_steps
)
== expected
@@ -76,7 +78,7 @@ def test_get_num_steps_from_num_tokens(
expected: int,
):
assert (
NumberConversion.get_num_steps_from_num_tokens(
TrainingNumberConversion.get_num_steps_from_num_tokens(
num_ranks, local_micro_batch_size, global_num_tokens, sequence_length, gradient_accumulation_steps
)
== expected
@@ -101,7 +103,7 @@ def test_get_num_tokens_from_num_steps(
expected: int,
):
assert (
NumberConversion.get_num_tokens_from_num_steps(
TrainingNumberConversion.get_num_tokens_from_num_steps(
num_steps=num_steps,
num_ranks=num_ranks,
local_micro_batch_size=local_micro_batch_size,
@@ -141,9 +143,9 @@ def test_get_last_step_from_checkpoint_path(checkpoint_path: Path, expected: int
if expected_exception:
# Expecting an exception for this test case
with pytest.raises(expected_exception):
NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path)
else:
assert NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected
assert TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected


@pytest.mark.parametrize(
@@ -175,9 +177,12 @@ def test_get_num_seen_steps_from_checkpoint_path(checkpoint_path: Path, expected
if expected_exception:
# Expecting an exception for this test case
with pytest.raises(expected_exception):
NumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path)
else:
assert NumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected
assert (
TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path)
== expected
)


@pytest.mark.parametrize(
@@ -211,10 +216,10 @@ def test_get_global_num_seen_tokens_from_checkpoint_path(
if expected_exception:
# Expecting an exception for this test case
with pytest.raises(expected_exception):
NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
else:
assert (
NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
== expected
)

@@ -250,10 +255,10 @@ def test_get_global_num_target_tokens_from_checkpoint_path(
if expected_exception:
# Expecting an exception for this test case
with pytest.raises(expected_exception):
NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
else:
assert (
NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path)
== expected
)

@@ -287,9 +292,12 @@ def test_get_num_target_steps_from_checkpoint_path(checkpoint_path: Path, expect
if expected_exception:
# Expecting an exception for this test case
with pytest.raises(expected_exception):
NumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path)
TrainingNumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path)
else:
assert NumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected
assert (
TrainingNumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path)
== expected
)


@pytest.mark.parametrize(
@@ -336,7 +344,7 @@ def test_get_num_tokens_from_packed_mem_map_dataset_continuous(
)

assert (
NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous(
TrainingNumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous(
dataset_path=dataset_path,
sequence_length=sequence_length,
num_ranks=num_ranks,
@@ -369,7 +377,7 @@ def test_num_steps_from_raw_dataset_index(
with open(raw_index_path, "rb") as f:
index_length = len(pickle.load(f))

num_steps_from_number_conversion = NumberConversion.get_num_steps_from_raw_dataset_index(
num_steps_from_number_conversion = TrainingNumberConversion.get_num_steps_from_raw_dataset_index(
raw_index_path=raw_index_path,
num_ranks=num_ranks,
local_micro_batch_size=local_micro_batch_size,