From 46f88b323286cc0fa6b63628828df681811050f7 Mon Sep 17 00:00:00 2001 From: Tom van der Weide Date: Thu, 13 Feb 2025 01:47:40 -0800 Subject: [PATCH] Use the max size of serialized examples to find a safe number of shards If we know the max size of serialized examples, then we can account for the worst case scenario where one shard would get only examples of the max size. This hopefully should prevent users running into problems with having too big shards. PiperOrigin-RevId: 726377778 --- .../huggingface_dataset_builder.py | 1 + tensorflow_datasets/core/utils/shard_utils.py | 29 ++++- .../core/utils/shard_utils_test.py | 100 ++++++++++++++++-- tensorflow_datasets/core/writer.py | 27 +++-- 4 files changed, 136 insertions(+), 21 deletions(-) diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py index 6be1af3d0ed..724f462065f 100644 --- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py @@ -406,6 +406,7 @@ def _compute_shard_specs( # HF split size is good enough for estimating the number of shards. num_shards = shard_utils.ShardConfig.calculate_number_shards( total_size=hf_split_info.num_bytes, + max_size=None, num_examples=hf_split_info.num_examples, uses_precise_sharding=False, ) diff --git a/tensorflow_datasets/core/utils/shard_utils.py b/tensorflow_datasets/core/utils/shard_utils.py index 429405335a7..64b38571d51 100644 --- a/tensorflow_datasets/core/utils/shard_utils.py +++ b/tensorflow_datasets/core/utils/shard_utils.py @@ -57,19 +57,22 @@ class ShardConfig: def calculate_number_shards( cls, total_size: int, + max_size: int | None, num_examples: int, uses_precise_sharding: bool = True, ) -> int: """Returns number of shards for num_examples of total_size in bytes. Args: - total_size: the size of the data (serialized, not couting any overhead). + total_size: the size of the data (serialized, not counting any overhead). + max_size: the maximum size of a single example (serialized, not counting + any overhead). num_examples: the number of records in the data. uses_precise_sharding: whether a mechanism is used to exactly control how many examples go in each shard. """ - total_size += num_examples * cls.overhead - max_shards_number = total_size // cls.min_shard_size + total_overhead = num_examples * cls.overhead + total_size_with_overhead = total_size + total_overhead if uses_precise_sharding: max_shard_size = cls.max_shard_size else: @@ -77,7 +80,16 @@ def calculate_number_shards( # shard (called 'precise sharding' here), we use a smaller max shard size # so that the pipeline doesn't fail if a shard gets some more examples. max_shard_size = 0.9 * cls.max_shard_size - min_shards_number = total_size // max_shard_size + max_shard_size = max(1, max_shard_size) + + if max_size is None: + min_shards_number = max(1, total_size_with_overhead // max_shard_size) + max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size) + else: + pessimistic_total_size = num_examples * (max_size + cls.overhead) + min_shards_number = max(1, pessimistic_total_size // max_shard_size) + max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size) + if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024: return 1024 elif min_shards_number > 1024: @@ -96,15 +108,22 @@ def calculate_number_shards( def get_number_shards( self, total_size: int, + max_size: int | None, num_examples: int, uses_precise_sharding: bool = True, ) -> int: if self.num_shards: return self.num_shards return self.calculate_number_shards( - total_size, num_examples, uses_precise_sharding + total_size=total_size, + max_size=max_size, + num_examples=num_examples, + uses_precise_sharding=uses_precise_sharding, ) + def replace(self, **kwargs: Any) -> ShardConfig: + return dataclasses.replace(self, **kwargs) + def get_shard_boundaries( num_examples: int, diff --git a/tensorflow_datasets/core/utils/shard_utils_test.py b/tensorflow_datasets/core/utils/shard_utils_test.py index 1882b178e37..b525ae8f8e7 100644 --- a/tensorflow_datasets/core/utils/shard_utils_test.py +++ b/tensorflow_datasets/core/utils/shard_utils_test.py @@ -22,16 +22,94 @@ class ShardConfigTest(parameterized.TestCase): @parameterized.named_parameters( - ('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024), - ('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64), - ('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512), - ('xxl, 10 TiB', 10 << 40, 10**9, True, 11264), - ('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808), - ('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1), - ('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4), + dict( + testcase_name='imagenet train, 137 GiB', + total_size=137 << 30, + num_examples=1281167, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=1024, + ), + dict( + testcase_name='imagenet evaluation, 6.3 GiB', + total_size=6300 * (1 << 20), + num_examples=50000, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=64, + ), + dict( + testcase_name='very large, but few examples, 52 GiB', + total_size=52 << 30, + num_examples=512, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=512, + ), + dict( + testcase_name='xxl, 10 TiB', + total_size=10 << 40, + num_examples=10**9, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=11264, + ), + dict( + testcase_name='xxl, 10 PiB, 100B examples', + total_size=10 << 50, + num_examples=10**11, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=10487808, + ), + dict( + testcase_name='xs, 100 MiB, 100K records', + total_size=10 << 20, + num_examples=100 * 10**3, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=1, + ), + dict( + testcase_name='m, 499 MiB, 200K examples', + total_size=400 << 20, + num_examples=200 * 10**3, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=4, + ), + dict( + testcase_name='100GiB, even example sizes', + num_examples=1e9, # 1B examples + total_size=1e9 * 1000, # On average 1000 bytes per example + max_size=1000, # Max example size is 4000 bytes + uses_precise_sharding=True, + expected_num_shards=1024, + ), + dict( + testcase_name='100GiB, uneven example sizes', + num_examples=1e9, # 1B examples + total_size=1e9 * 1000, # On average 1000 bytes per example + max_size=4 * 1000, # Max example size is 4000 bytes + uses_precise_sharding=True, + expected_num_shards=4096, + ), + dict( + testcase_name='100GiB, very uneven example sizes', + num_examples=1e9, # 1B examples + total_size=1e9 * 1000, # On average 1000 bytes per example + max_size=16 * 1000, # Max example size is 16x the average bytes + uses_precise_sharding=True, + expected_num_shards=15360, + ), ) def test_get_number_shards_default_config( - self, total_size, num_examples, uses_precise_sharding, expected_num_shards + self, + total_size: int, + num_examples: int, + uses_precise_sharding: bool, + max_size: int, + expected_num_shards: int, ): shard_config = shard_utils.ShardConfig() self.assertEqual( @@ -39,6 +117,7 @@ def test_get_number_shards_default_config( shard_config.get_number_shards( total_size=total_size, num_examples=num_examples, + max_size=max_size, # max(1, total_size // num_examples), uses_precise_sharding=uses_precise_sharding, ), ) @@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self): self.assertEqual( 42, shard_config.get_number_shards( - total_size=100, num_examples=1, uses_precise_sharding=True + total_size=100, + max_size=100, + num_examples=1, + uses_precise_sharding=True, ), ) diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py index aa756bbcfb5..0cefbee63aa 100644 --- a/tensorflow_datasets/core/writer.py +++ b/tensorflow_datasets/core/writer.py @@ -116,6 +116,7 @@ def _get_index_path(path: str) -> epath.PathLike: def _get_shard_specs( num_examples: int, total_size: int, + max_size: int | None, bucket_lengths: Sequence[int], filename_template: naming.ShardedFileTemplate, shard_config: shard_utils.ShardConfig, @@ -125,11 +126,14 @@ def _get_shard_specs( Args: num_examples: int, number of examples in split. total_size: int (bytes), sum of example sizes. + max_size: int (bytes), maximum size of a single example. bucket_lengths: list of ints, number of examples in each bucket. filename_template: template to format sharded filenames. shard_config: the configuration for creating shards. """ - num_shards = shard_config.get_number_shards(total_size, num_examples) + num_shards = shard_config.get_number_shards( + total_size=total_size, max_size=max_size, num_examples=num_examples + ) shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards) shard_specs = [] bucket_indexes = [str(i) for i in range(len(bucket_lengths))] @@ -372,6 +376,7 @@ def finalize(self) -> tuple[list[int], int]: shard_specs = _get_shard_specs( num_examples=self._shuffler.num_examples, total_size=self._shuffler.size, + max_size=None, bucket_lengths=self._shuffler.bucket_lengths, filename_template=self._filename_template, shard_config=self._shard_config, @@ -589,10 +594,13 @@ def _write_final_shard( id=shard_id, num_examples=len(example_by_key), size=shard_size ) - def _number_of_shards(self, num_examples: int, total_size: int) -> int: + def _number_of_shards( + self, num_examples: int, total_size: int, max_size: int + ) -> int: """Returns the number of shards.""" num_shards = self._shard_config.get_number_shards( total_size=total_size, + max_size=max_size, num_examples=num_examples, uses_precise_sharding=False, ) @@ -658,16 +666,21 @@ def write_from_pcollection(self, examples_pcollection): | "CountExamples" >> beam.combiners.Count.Globally() | "CheckValidNumExamples" >> beam.Map(self._check_num_examples) ) + serialized_example_sizes = ( + serialized_examples | beam.Values() | beam.Map(len) + ) total_size = beam.pvalue.AsSingleton( - serialized_examples - | beam.Values() - | beam.Map(len) - | "TotalSize" >> beam.CombineGlobally(sum) + serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum) + ) + max_size = beam.pvalue.AsSingleton( + serialized_example_sizes | "MaxSize" >> beam.CombineGlobally(max) ) ideal_num_shards = beam.pvalue.AsSingleton( num_examples | "NumberOfShards" - >> beam.Map(self._number_of_shards, total_size=total_size) + >> beam.Map( + self._number_of_shards, total_size=total_size, max_size=max_size + ) ) examples_per_shard = (