Merge remote-tracking branch 'origin/main' into v0.3.0
ylgh committed Oct 27, 2022
2 parents 8e4a255 + daabdf4 commit d279c56
Showing 2 changed files with 56 additions and 45 deletions.
63 changes: 37 additions & 26 deletions torchrec/datasets/criteo.py
@@ -271,6 +271,8 @@ def get_file_idx_to_row_range(
         lengths: List[int],
         rank: int,
         world_size: int,
+        start_row: int = 0,
+        last_row: Optional[int] = None,
     ) -> Dict[int, Tuple[int, int]]:
         """
         Given a rank, world_size, and the lengths (number of rows) for a list of files,
@@ -296,14 +298,26 @@ def get_file_idx_to_row_range(
         # All ..._g variables are global indices (meaning they range from 0 to
         # total_length - 1). All ..._l variables are local indices (meaning they range
         # from 0 to lengths[i] - 1 for the ith file).

-        total_length = sum(lengths)
+        if last_row is None:
+            total_length = sum(lengths) - start_row
+        else:
+            total_length = last_row - start_row + 1
         rows_per_rank = total_length // world_size
+        remainder = total_length % world_size

         # Global indices that rank is responsible for. All ranges (left, right) are
         # inclusive.
-        rank_left_g = rank * rows_per_rank
-        rank_right_g = (rank + 1) * rows_per_rank - 1
+        if rank < remainder:
+            rank_left_g = rank * (rows_per_rank + 1)
+            rank_right_g = (rank + 1) * (rows_per_rank + 1) - 1
+        else:
+            rank_left_g = (
+                remainder * (rows_per_rank + 1) + (rank - remainder) * rows_per_rank
+            )
+            rank_right_g = rank_left_g + rows_per_rank - 1
+
+        rank_left_g += start_row
+        rank_right_g += start_row

         output = {}
@@ -734,34 +748,31 @@ def __init__(
         }

     def _load_data_for_rank(self) -> None:
-        if self.stage == "train":
-            file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
-                lengths=[
-                    BinaryCriteoUtils.get_shape_from_npy(
-                        path, path_manager_key=self.path_manager_key
-                    )[0]
-                    for path in self.dense_paths
-                ],
-                rank=self.rank,
-                world_size=self.world_size,
-            )
-        elif self.stage in ["val", "test"]:
+        start_row, last_row = 0, None
+        if self.stage in ["val", "test"]:
             # Last day's dataset is split into 2 sets: 1st half for "val"; 2nd for "test"
             samples_in_file = BinaryCriteoUtils.get_shape_from_npy(
                 self.dense_paths[0], path_manager_key=self.path_manager_key
             )[0]

-            dataset_start = 0
+            start_row = 0
             dataset_len = int(np.ceil(samples_in_file / 2.0))

             if self.stage == "test":
-                dataset_start = dataset_len
-                dataset_len = samples_in_file - dataset_len
-            segment_len = dataset_len // self.world_size
-            rank_start_row = dataset_start + self.rank * segment_len
-
-            rank_last_row = rank_start_row + segment_len - 1
-            file_idx_to_row_range = {0: (rank_start_row, rank_last_row)}
+                start_row = dataset_len
+                dataset_len = samples_in_file - start_row
+            last_row = start_row + dataset_len - 1
+
+        file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
+            lengths=[
+                BinaryCriteoUtils.get_shape_from_npy(
+                    path, path_manager_key=self.path_manager_key
+                )[0]
+                for path in self.dense_paths
+            ],
+            rank=self.rank,
+            world_size=self.world_size,
+            start_row=start_row,
+            last_row=last_row,
+        )

         self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], []
         for arrs, paths in zip(
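With this change the stages share one call path: "train" keeps the defaults (start_row=0, last_row=None), while "val" and "test" each restrict the range to one half of the last day's file. A hedged sketch of that bound computation, with an invented helper name (the real code computes these inline, using np.ceil):

import math
from typing import Optional, Tuple

def stage_bounds(stage: str, samples_in_file: int) -> Tuple[int, Optional[int]]:
    # The last day's file is halved: first half serves "val", second "test".
    # Bounds are inclusive; last_row=None means "through the end of all files".
    start_row, last_row = 0, None
    if stage in ("val", "test"):
        dataset_len = math.ceil(samples_in_file / 2)  # the diff uses np.ceil
        if stage == "test":
            start_row = dataset_len
            dataset_len = samples_in_file - start_row
        last_row = start_row + dataset_len - 1
    return start_row, last_row

# 11 samples: "train" -> (0, None), "val" -> (0, 5), "test" -> (6, 10)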
38 changes: 19 additions & 19 deletions torchrec/datasets/tests/test_criteo.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
+import math
 import os
 import random
 import tempfile
@@ -239,12 +240,9 @@ def _validate_sparse_to_contiguous_preproc(
             input_files, temp_output_dir, freq_threshold, columns
         )

-        output_files = list(
-            map(
-                lambda f: os.path.join(temp_output_dir, f),
-                os.listdir(temp_output_dir),
-            )
-        )
+        output_files = [
+            os.path.join(temp_output_dir, f) for f in os.listdir(temp_output_dir)
+        ]
         output_files.sort()
         for day, file in enumerate(output_files):
             processed_data = np.load(file)
@@ -280,9 +278,9 @@ def test_shuffle(self) -> None:
         labels_data = [np.array([[i], [i + 3], [i + 6]]) for i in range(3)]

         def save_data_list(data: List[np.ndarray], data_type: str) -> None:
-            for day, data in enumerate(data):
+            for day, data_ in enumerate(data):
                 file = os.path.join(temp_input_dir, f"day_{day}_{data_type}.npy")
-                np.save(file, data)
+                np.save(file, data_)

         save_data_list(dense_data, "dense")
         save_data_list(sparse_data, "sparse")
@@ -380,14 +378,14 @@ def _test_dataset(
             dataset_start = num_rows // 2 + num_rows % 2
             dataset_len = num_rows // 2

-        incomplete_last_batch_size = dataset_len // world_size % batch_size
-        num_batches = dataset_len // world_size // batch_size + (
-            incomplete_last_batch_size != 0
-        )
-
         lens = []
         samples_counts = []
+        remainder = dataset_len % world_size
         for rank in range(world_size):
+            incomplete_last_batch_size = (
+                dataset_len // world_size % batch_size + int(rank < remainder)
+            )
+            num_samples = dataset_len // world_size + int(rank < remainder)
+            num_batches = math.ceil(num_samples / batch_size)
             datapipe = InMemoryBinaryCriteoIterDataPipe(
                 stage=stage,
                 dense_paths=[f[0] for f in files],
@@ -421,12 +419,14 @@ def _test_dataset(
             # Check that dataset __len__ matches true length.
             self.assertEqual(datapipe_len, len_)
             lens.append(len_)
-            self.assertEqual(samples_count, dataset_len // world_size)
             samples_counts.append(samples_count)
+            self.assertEqual(samples_count, num_samples)

-        # Ensure all ranks' datapipes return the same number of batches.
-        self.assertEqual(len(set(lens)), 1)
-        self.assertEqual(len(set(samples_counts)), 1)
+        # Ensure all ranks return the correct number of batches.
+        if remainder > 0:
+            self.assertEqual(len(set(lens[:remainder])), 1)
+            self.assertEqual(len(set(lens[remainder:])), 1)
+        else:
+            self.assertEqual(len(set(lens)), 1)

     def test_dataset_small_files(self) -> None:
         self._test_dataset([1] * 20, 4, 2)
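The reworked test mirrors the loader's split: each rank expects dataset_len // world_size samples, plus one extra while rank < remainder, so batch counts may differ by one between the first remainder ranks and the rest (hence the two set checks above). A standalone sketch of that accounting, not from the commit and with invented names:

import math
from typing import List, Tuple

def per_rank_counts(
    dataset_len: int, world_size: int, batch_size: int
) -> List[Tuple[int, int]]:
    # Returns (num_samples, num_batches) per rank: the first `remainder`
    # ranks hold one extra sample; the batch count is a ceiling division.
    remainder = dataset_len % world_size
    counts = []
    for rank in range(world_size):
        num_samples = dataset_len // world_size + int(rank < remainder)
        counts.append((num_samples, math.ceil(num_samples / batch_size)))
    return counts

# dataset_len=10, world_size=4, batch_size=2 -> [(3, 2), (3, 2), (2, 1), (2, 1)]
print(per_rank_counts(10, 4, 2))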
