Use the max size of serialized examples to find a safe number of shards
If we know the max size of serialized examples, we can account for the worst-case scenario in which one shard receives only examples of that max size. This should help prevent users from running into problems with shards that are too big.

PiperOrigin-RevId: 726377778
tomvdw authored and The TensorFlow Datasets Authors committed Feb 13, 2025
1 parent 281ce2d commit 46f88b3
Showing 4 changed files with 136 additions and 21 deletions.
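
As a minimal sketch of the idea (illustrative names only — pessimistic_min_shards and per_example_overhead are not TFDS APIs; the real logic is in shard_utils.ShardConfig.calculate_number_shards in the diff below):

# Hedged sketch, not the TFDS implementation: when the max serialized example
# size is known, size shards as if every example were as large as the biggest one.
def pessimistic_min_shards(
    num_examples: int,
    total_size: int,
    max_size: int | None,
    per_example_overhead: int,
    max_shard_size: int,
) -> int:
  if max_size is None:
    # Fall back to the actual serialized total (plus bookkeeping overhead).
    size = total_size + num_examples * per_example_overhead
  else:
    # Worst case: a shard could be filled with nothing but max-size examples.
    size = num_examples * (max_size + per_example_overhead)
  return max(1, size // max_shard_size)

The diffs below thread max_size through a caller (_compute_shard_specs), shard_utils.ShardConfig, its tests, and the Beam-based writer.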
@@ -406,6 +406,7 @@ def _compute_shard_specs(
# HF split size is good enough for estimating the number of shards.
num_shards = shard_utils.ShardConfig.calculate_number_shards(
total_size=hf_split_info.num_bytes,
max_size=None,
num_examples=hf_split_info.num_examples,
uses_precise_sharding=False,
)
29 changes: 24 additions & 5 deletions tensorflow_datasets/core/utils/shard_utils.py
@@ -57,27 +57,39 @@ class ShardConfig:
def calculate_number_shards(
cls,
total_size: int,
max_size: int | None,
num_examples: int,
uses_precise_sharding: bool = True,
) -> int:
"""Returns number of shards for num_examples of total_size in bytes.
Args:
total_size: the size of the data (serialized, not couting any overhead).
total_size: the size of the data (serialized, not counting any overhead).
max_size: the maximum size of a single example (serialized, not counting
any overhead).
num_examples: the number of records in the data.
uses_precise_sharding: whether a mechanism is used to exactly control how
many examples go in each shard.
"""
total_size += num_examples * cls.overhead
max_shards_number = total_size // cls.min_shard_size
total_overhead = num_examples * cls.overhead
total_size_with_overhead = total_size + total_overhead
if uses_precise_sharding:
max_shard_size = cls.max_shard_size
else:
# When the pipeline does not control exactly how many rows go into each
# shard (called 'precise sharding' here), we use a smaller max shard size
# so that the pipeline doesn't fail if a shard gets some more examples.
max_shard_size = 0.9 * cls.max_shard_size
min_shards_number = total_size // max_shard_size
max_shard_size = max(1, max_shard_size)

if max_size is None:
min_shards_number = max(1, total_size_with_overhead // max_shard_size)
max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size)
else:
pessimistic_total_size = num_examples * (max_size + cls.overhead)
min_shards_number = max(1, pessimistic_total_size // max_shard_size)
max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size)

if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024:
return 1024
elif min_shards_number > 1024:
@@ -96,15 +108,22 @@ def calculate_number_shards
def get_number_shards(
self,
total_size: int,
max_size: int | None,
num_examples: int,
uses_precise_sharding: bool = True,
) -> int:
if self.num_shards:
return self.num_shards
return self.calculate_number_shards(
total_size, num_examples, uses_precise_sharding
total_size=total_size,
max_size=max_size,
num_examples=num_examples,
uses_precise_sharding=uses_precise_sharding,
)

def replace(self, **kwargs: Any) -> ShardConfig:
return dataclasses.replace(self, **kwargs)


def get_shard_boundaries(
num_examples: int,
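A quick usage sketch of the new parameter, mirroring two of the parameterized test cases added in shard_utils_test.py below (the expected counts come from those tests, not from running this snippet):

from tensorflow_datasets.core.utils import shard_utils

config = shard_utils.ShardConfig()

# Max example size equals the 1000-byte average: behaves like the old estimate.
even = config.get_number_shards(
    total_size=10**9 * 1000,  # ~1 TB of serialized data
    max_size=1000,
    num_examples=10**9,
    uses_precise_sharding=True,
)

# Max example size is 4x the average: the pessimistic bound asks for more shards.
uneven = config.get_number_shards(
    total_size=10**9 * 1000,
    max_size=4 * 1000,
    num_examples=10**9,
    uses_precise_sharding=True,
)

print(even, uneven)  # Per the tests below: 1024 and 4096.
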
100 changes: 91 additions & 9 deletions tensorflow_datasets/core/utils/shard_utils_test.py
@@ -22,23 +22,102 @@
class ShardConfigTest(parameterized.TestCase):

@parameterized.named_parameters(
('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024),
('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64),
('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512),
('xxl, 10 TiB', 10 << 40, 10**9, True, 11264),
('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808),
('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1),
('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4),
dict(
testcase_name='imagenet train, 137 GiB',
total_size=137 << 30,
num_examples=1281167,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=1024,
),
dict(
testcase_name='imagenet evaluation, 6.3 GiB',
total_size=6300 * (1 << 20),
num_examples=50000,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=64,
),
dict(
testcase_name='very large, but few examples, 52 GiB',
total_size=52 << 30,
num_examples=512,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=512,
),
dict(
testcase_name='xxl, 10 TiB',
total_size=10 << 40,
num_examples=10**9,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=11264,
),
dict(
testcase_name='xxl, 10 PiB, 100B examples',
total_size=10 << 50,
num_examples=10**11,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=10487808,
),
dict(
testcase_name='xs, 100 MiB, 100K records',
total_size=10 << 20,
num_examples=100 * 10**3,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=1,
),
dict(
testcase_name='m, 499 MiB, 200K examples',
total_size=400 << 20,
num_examples=200 * 10**3,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=4,
),
dict(
testcase_name='100GiB, even example sizes',
num_examples=1e9, # 1B examples
total_size=1e9 * 1000, # On average 1000 bytes per example
max_size=1000,  # Max example size equals the 1000-byte average
uses_precise_sharding=True,
expected_num_shards=1024,
),
dict(
testcase_name='100GiB, uneven example sizes',
num_examples=1e9, # 1B examples
total_size=1e9 * 1000, # On average 1000 bytes per example
max_size=4 * 1000, # Max example size is 4000 bytes
uses_precise_sharding=True,
expected_num_shards=4096,
),
dict(
testcase_name='100GiB, very uneven example sizes',
num_examples=1e9, # 1B examples
total_size=1e9 * 1000, # On average 1000 bytes per example
max_size=16 * 1000,  # Max example size is 16x the 1000-byte average
uses_precise_sharding=True,
expected_num_shards=15360,
),
)
def test_get_number_shards_default_config(
self, total_size, num_examples, uses_precise_sharding, expected_num_shards
self,
total_size: int,
num_examples: int,
uses_precise_sharding: bool,
max_size: int,
expected_num_shards: int,
):
shard_config = shard_utils.ShardConfig()
self.assertEqual(
expected_num_shards,
shard_config.get_number_shards(
total_size=total_size,
num_examples=num_examples,
max_size=max_size, # max(1, total_size // num_examples),
uses_precise_sharding=uses_precise_sharding,
),
)
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
self.assertEqual(
42,
shard_config.get_number_shards(
total_size=100, num_examples=1, uses_precise_sharding=True
total_size=100,
max_size=100,
num_examples=1,
uses_precise_sharding=True,
),
)

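To see where the uneven-size numbers come from (assuming a default max shard size of roughly 1 GiB and a small per-example overhead — assumptions about ShardConfig defaults, not values shown in this diff): 10**9 examples with a 4000-byte maximum give a pessimistic total of about 4 TB; dividing by 1 GiB yields a minimum of roughly 3,700 shards, which the sizing logic (not fully visible in the hunk above) appears to round up to the next multiple of 1024, i.e. the expected 4096. The 16x case lands near 15,000 and rounds up to 15360 the same way.
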
27 changes: 20 additions & 7 deletions tensorflow_datasets/core/writer.py
@@ -116,6 +116,7 @@ def _get_index_path(path: str) -> epath.PathLike:
def _get_shard_specs(
num_examples: int,
total_size: int,
max_size: int | None,
bucket_lengths: Sequence[int],
filename_template: naming.ShardedFileTemplate,
shard_config: shard_utils.ShardConfig,
@@ -125,11 +126,14 @@ def _get_shard_specs(
Args:
num_examples: int, number of examples in split.
total_size: int (bytes), sum of example sizes.
max_size: int (bytes), maximum size of a single example.
bucket_lengths: list of ints, number of examples in each bucket.
filename_template: template to format sharded filenames.
shard_config: the configuration for creating shards.
"""
num_shards = shard_config.get_number_shards(total_size, num_examples)
num_shards = shard_config.get_number_shards(
total_size=total_size, max_size=max_size, num_examples=num_examples
)
shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards)
shard_specs = []
bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
@@ -372,6 +376,7 @@ def finalize(self) -> tuple[list[int], int]:
shard_specs = _get_shard_specs(
num_examples=self._shuffler.num_examples,
total_size=self._shuffler.size,
max_size=None,
bucket_lengths=self._shuffler.bucket_lengths,
filename_template=self._filename_template,
shard_config=self._shard_config,
@@ -589,10 +594,13 @@ def _write_final_shard(
id=shard_id, num_examples=len(example_by_key), size=shard_size
)

def _number_of_shards(self, num_examples: int, total_size: int) -> int:
def _number_of_shards(
self, num_examples: int, total_size: int, max_size: int
) -> int:
"""Returns the number of shards."""
num_shards = self._shard_config.get_number_shards(
total_size=total_size,
max_size=max_size,
num_examples=num_examples,
uses_precise_sharding=False,
)
@@ -658,16 +666,21 @@ def write_from_pcollection(self, examples_pcollection):
| "CountExamples" >> beam.combiners.Count.Globally()
| "CheckValidNumExamples" >> beam.Map(self._check_num_examples)
)
serialized_example_sizes = (
serialized_examples | beam.Values() | beam.Map(len)
)
total_size = beam.pvalue.AsSingleton(
serialized_examples
| beam.Values()
| beam.Map(len)
| "TotalSize" >> beam.CombineGlobally(sum)
serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum)
)
max_size = beam.pvalue.AsSingleton(
serialized_example_sizes | "MaxSize" >> beam.CombineGlobally(max)
)
ideal_num_shards = beam.pvalue.AsSingleton(
num_examples
| "NumberOfShards"
>> beam.Map(self._number_of_shards, total_size=total_size)
>> beam.Map(
self._number_of_shards, total_size=total_size, max_size=max_size
)
)

examples_per_shard = (
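For readers less familiar with Beam side inputs, here is a self-contained sketch of the pattern used above — a toy pipeline with illustrative names and an invented stand-in for get_number_shards, not the TFDS writer itself:

import apache_beam as beam


def _toy_num_shards(num_examples: int, total_size: int, max_size: int) -> int:
  # Illustrative stand-in for ShardConfig.get_number_shards: a pessimistic
  # bound that assumes every example is as large as the biggest one.
  assumed_max_shard_size = 64 << 20  # 64 MiB; an assumption for this toy only
  del total_size  # the real implementation also considers the actual total
  return max(1, (num_examples * max_size) // assumed_max_shard_size)


with beam.Pipeline() as pipeline:
  serialized = pipeline | 'Create' >> beam.Create([b'x' * 100, b'y' * 4000, b'z' * 250])
  sizes = serialized | 'Sizes' >> beam.Map(len)
  total_size = beam.pvalue.AsSingleton(sizes | 'TotalSize' >> beam.CombineGlobally(sum))
  max_size = beam.pvalue.AsSingleton(sizes | 'MaxSize' >> beam.CombineGlobally(max))
  _ = (
      sizes
      | 'Count' >> beam.combiners.Count.Globally()
      | 'NumShards' >> beam.Map(_toy_num_shards, total_size=total_size, max_size=max_size)
      | 'Print' >> beam.Map(print)
  )

The singleton side inputs let the per-element Map see the global total and max sizes without a separate pipeline stage, which is how the writer above feeds _number_of_shards.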
