diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
index 6be1af3d0ed..aab4aff685b 100644
--- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
@@ -26,28 +26,23 @@
 from __future__ import annotations

-from collections.abc import Mapping, Sequence
+import collections
+from collections.abc import Iterable, Mapping, Sequence
 import dataclasses
 import functools
-import itertools
-import multiprocessing
+import math
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any

 from absl import logging
 from etils import epath
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
-from tensorflow_datasets.core import download
-from tensorflow_datasets.core import example_serializer
-from tensorflow_datasets.core import features as feature_lib
 from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import split_builder as split_builder_lib
-from tensorflow_datasets.core import splits as splits_lib
+from tensorflow_datasets.core.download import download_manager
 from tensorflow_datasets.core.utils import conversion_utils
 from tensorflow_datasets.core.utils import huggingface_utils
-from tensorflow_datasets.core.utils import shard_utils
-from tensorflow_datasets.core.utils import tqdm_utils
 from tensorflow_datasets.core.utils import version as version_lib
 from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
 from tensorflow_datasets.core.utils.lazy_imports_utils import huggingface_hub
@@ -67,7 +62,6 @@ class _ShardSpec:
   """Spec to write a shard.

   Attributes:
-    path: Shard path.
     hf_split: HuggingFace split name.
     split: TFDS split name.
     start_index: Index of the shard start.
@@ -76,9 +70,7 @@
     shard_split: HuggingFace split for the shard.
   """

-  path: epath.Path
   hf_split: str
-  split: str
   start_index: int
   end_index: int

@@ -91,91 +83,8 @@ def shard_split(self) -> str:
     return f'{self.hf_split}[{self.start_index}:{self.end_index}]'
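For orientation (not part of the patch): shard_split maps a shard spec onto Hugging Face's split-slicing syntax, so each shard reads only its own index range. A minimal sketch, assuming the remaining constructor fields are hf_split/start_index/end_index, as in the updated call sites later in this patch:

  spec = _ShardSpec(hf_split='train', start_index=0, end_index=1000)
  assert spec.num_examples == 1000
  assert spec.shard_split == 'train[0:1000]'
  # hf_builder.as_dataset(split='train[0:1000]') then materializes just that slice.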
- """ - serialized_info = features.get_serialized_info() - serializer = example_serializer.ExampleSerializer(serialized_info) - num_bytes = 0 - num_exceptions = 0 - - def get_serialized_examples_iter(): - nonlocal num_bytes - nonlocal num_exceptions - dataset = hf_builder.as_dataset( - split=shard_spec.shard_split, run_post_process=False - ) - for i in range(shard_spec.num_examples): - try: - hf_value = dataset[i] - except Exception: # pylint: disable=broad-exception-caught - num_exceptions += 1 - if ignore_hf_errors: - logging.exception('Ignoring Hugging Face error') - continue - else: - raise - example = conversion_utils.to_tfds_value(hf_value, features) - encoded_example = features.encode_example(example) - serialized_example = serializer.serialize_example(encoded_example) - num_bytes += len(serialized_example) - yield serialized_example - - example_writer.write( - os.fspath(shard_spec.path), - tqdm_utils.tqdm( - enumerate(get_serialized_examples_iter()), - desc=f'Writing {shard_spec.path} examples...', - unit=' examples', - total=shard_spec.num_examples, - leave=False, - mininterval=1.0, - ), - ) - - return _ShardInfo( - num_bytes=num_bytes, - num_examples=shard_spec.num_examples - num_exceptions, - num_exceptions=num_exceptions, - ) - - class HuggingfaceDatasetBuilder( - dataset_builder.GeneratorBasedBuilder, skip_registration=True + dataset_builder.ShardBasedBuilder, skip_registration=True ): """A TFDS builder for Huggingface datasets. @@ -190,14 +99,13 @@ class HuggingfaceDatasetBuilder( def __init__( self, *, - file_format: Optional[Union[str, file_adapters.FileFormat]] = None, + file_format: str | file_adapters.FileFormat | None = None, hf_repo_id: str, - hf_config: Optional[str] = None, + hf_config: str | None = None, ignore_verifications: bool = False, - data_dir: Optional[epath.PathLike] = None, - hf_hub_token: Optional[str] = None, - hf_num_proc: Optional[int] = None, - tfds_num_proc: Optional[int] = None, + data_dir: epath.PathLike | None = None, + hf_hub_token: str | None = None, + hf_num_proc: int | None = None, ignore_hf_errors: bool = False, overwrite_version: str | None = None, **config_kwargs, @@ -218,6 +126,7 @@ def __init__( f' hf_repo_id={self._hf_repo_id}, hf_config={self._hf_config},' f' config_kwargs={self.config_kwargs}' ) from e + self._overwrite_version = overwrite_version version = str( overwrite_version or self._hf_info.version @@ -229,9 +138,9 @@ def __init__( self.homepage = f'https://huggingface.co/datasets/{hf_repo_id}' self._hf_hub_token = hf_hub_token self._hf_num_proc = hf_num_proc - self._tfds_num_proc = tfds_num_proc + self._ignore_verifications = ignore_verifications self._verification_mode = ( - 'no_checks' if ignore_verifications else 'all_checks' + 'no_checks' if self._ignore_verifications else 'all_checks' ) if self._hf_config: description = self._get_text_field('description') @@ -252,13 +161,32 @@ def __init__( self.generation_errors = [] self._ignore_hf_errors = ignore_hf_errors + def __getstate__(self): + state = super().__getstate__() + state['_hf_state'] = dict( + hf_repo_id=self._hf_repo_id, + hf_config=self._hf_config, + ignore_verifications=self._ignore_verifications, + hf_hub_token=self._hf_hub_token, + hf_num_proc=self._hf_num_proc, + ignore_hf_errors=self._ignore_hf_errors, + overwrite_version=self._overwrite_version, + config_kwargs=self.config_kwargs, + ) + return state + + def __setstate__(self, state): + kwargs = state['_original_state'] + kwargs.update(state['_hf_state']) + self.__init__(**kwargs) + @property - def 
   @property
-  def builder_config(self) -> Optional[Any]:
+  def builder_config(self) -> Any | None:
     return self._converted_builder_config

   def _create_builder_config(
       self, builder_config, version
-  ) -> Optional[dataset_builder.BuilderConfig]:
+  ) -> dataset_builder.BuilderConfig | None:
     return self._converted_builder_config

   @functools.lru_cache(maxsize=1)
@@ -355,136 +283,94 @@ def _info(self) -> dataset_info_lib.DatasetInfo:
         homepage=self.homepage,
     )

-  def _split_generators(
-      self, dl_manager: download.DownloadManager
-  ) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
-    raise NotImplementedError('This method should not be called.')
-
-  def _generate_examples(self, data) -> split_builder_lib.SplitGenerator:
-    raise NotImplementedError('This method should not be called.')
-
-  def _generate_splits(
-      self,
-      dl_manager: download.DownloadManager,
-      download_config: download.DownloadConfig,
-  ) -> Sequence[splits_lib.SplitInfo]:
-    """Prepares the dataset by writing to shards directly."""
-    del dl_manager, download_config  # Unused.
+  def _shard_iterators_per_split(
+      self, dl_manager: download_manager.DownloadManager
+  ) -> Mapping[str, Sequence[split_builder_lib.ExampleGeneratorFn]]:
+    del dl_manager  # Unused.
     self._hf_download_and_prepare()
-
-    shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
+    if self._hf_info.splits is None:
+      raise ValueError('No splits found in the HuggingFace dataset.')
+
+    features = self.info.features
+    if features is None:
+      raise ValueError('No features found in the TFDS dataset.')
+
+    def _example_generator(
+        shard_spec: _ShardSpec,
+    ) -> Iterable[split_builder_lib.KeyExample]:
+      dataset = self._hf_builder.as_dataset(
+          split=shard_spec.shard_split, run_post_process=False
+      )
+      for i in range(shard_spec.num_examples):
+        try:
+          hf_value = dataset[i]
+        except Exception:  # pylint: disable=broad-exception-caught
+          if self._ignore_hf_errors:
+            logging.exception('Ignoring Hugging Face error')
+            continue
+          else:
+            raise
+        example = conversion_utils.to_tfds_value(hf_value, feature=features)
+        encoded_example = features.encode_example(example)
+        yield i, encoded_example
+
+    example_generators_per_split: dict[
+        str, list[split_builder_lib.ExampleGeneratorFn]
+    ] = collections.defaultdict(list)
     for hf_split, hf_split_info in self._hf_info.splits.items():
       split = conversion_utils.to_tfds_name(hf_split)
-      shard_specs_by_split[split] = self._compute_shard_specs(
-          hf_split_info, split
-      )
-
-    shard_infos_by_split = self._write_shards(shard_specs_by_split)
-    split_infos: list[splits_lib.SplitInfo] = []
-    for split, shard_infos in shard_infos_by_split.items():
-      shard_lengths = [shard_info.num_examples for shard_info in shard_infos]
-      num_bytes = sum(shard_info.num_bytes for shard_info in shard_infos)
-      split_infos.append(
-          splits_lib.SplitInfo(
-              name=split,
-              shard_lengths=shard_lengths,
-              num_bytes=num_bytes,
-              filename_template=self._get_filename_template(split),
-          )
-      )
-    return split_infos
+      shard_specs = self._compute_shard_specs(hf_split_info)
+      for shard_spec in shard_specs:
+        example_generators_per_split[split].append(
+            functools.partial(_example_generator, shard_spec=shard_spec)
+        )
+    return example_generators_per_split
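The new _shard_iterators_per_split hook returns, per TFDS split name, one zero-argument callable per output shard; each callable yields (key, encoded example) pairs for exactly its slice of the HuggingFace split, and the ShardBasedBuilder base class takes care of writing them out. An illustrative sketch of the returned shape (how the base class consumes it is an assumption, not shown in this patch):

  # 'builder' as constructed in the earlier sketch; dl_manager is unused by this builder.
  generators = builder._shard_iterators_per_split(dl_manager=None)
  # e.g. {'train': [shard_fn_0, shard_fn_1], 'validation': [shard_fn_0]}
  for split_name, shard_fns in generators.items():
    for shard_index, shard_fn in enumerate(shard_fns):
      for key, example in shard_fn():  # keys are shard-local indices
        ...  # one output shard is written per callable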

   def _compute_shard_specs(
-      self, hf_split_info: hf_datasets.SplitInfo, split: str
+      self, hf_split_info: hf_datasets.SplitInfo
   ) -> Sequence[_ShardSpec]:
     """Returns specs for evenly spread shards.

     Args:
       hf_split_info: HuggingFace split info.
-      split: TFDS split name.
     """
-    # HF split size is good enough for estimating the number of shards.
-    num_shards = shard_utils.ShardConfig.calculate_number_shards(
-        total_size=hf_split_info.num_bytes,
-        num_examples=hf_split_info.num_examples,
-        uses_precise_sharding=False,
-    )
-    filename_template = self._get_filename_template(split)
-    shard_boundaries = shard_utils.get_shard_boundaries(
-        num_examples=hf_split_info.num_examples, number_of_shards=num_shards
-    )
+    shard_lengths = []
+    if hf_split_info.shard_lengths is None:
+      if hf_split_info.num_bytes:
+        # Aim for 1 GB per shard.
+        num_1gb_shards = math.ceil(
+            hf_split_info.num_bytes / (1024 * 1024 * 1024)
+        )
+        num_examples_per_shard = hf_split_info.num_examples // num_1gb_shards
+        shard_lengths = [num_examples_per_shard] * (num_1gb_shards - 1)
+        # Put the remaining examples in the last shard.
+        current_num_examples = sum(shard_lengths)
+        shard_lengths.append(hf_split_info.num_examples - current_num_examples)
+      else:
+        shard_lengths = [hf_split_info.num_examples]
+      logging.info(
+          'No shard lengths found for split %s, using %d output shards.',
+          hf_split_info.name,
+          len(shard_lengths),
+      )
+    else:
+      shard_lengths = hf_split_info.shard_lengths

-    prev_shard_boundary = 0
     shard_specs: list[_ShardSpec] = []
-
-    for shard_index, shard_boundary in enumerate(shard_boundaries):
+    prev_shard_boundary = 0
+    for shard_length in shard_lengths:
+      next_shard_boundary = prev_shard_boundary + shard_length
       shard_specs.append(
           _ShardSpec(
-              path=filename_template.sharded_filepath(
-                  shard_index=shard_index, num_shards=len(shard_boundaries)
-              ),
               hf_split=hf_split_info.name,
-              split=split,
               start_index=prev_shard_boundary,
-              end_index=shard_boundary,
+              end_index=next_shard_boundary,
           )
       )
-      prev_shard_boundary = shard_boundary
+      prev_shard_boundary = next_shard_boundary
     return shard_specs

-  def _write_shards(
-      self,
-      shard_specs_by_split: Mapping[str, Sequence[_ShardSpec]],
-  ) -> Mapping[str, Sequence[_ShardInfo]]:
-    """Writes shards to files.
-
-    Args:
-      shard_specs_by_split: Shard specs by split name.
-
-    Returns:
-      Shard sizes in bytes.
-    """
-    shard_specs = list(itertools.chain(*shard_specs_by_split.values()))
-    shard_specs = tqdm_utils.tqdm(
-        shard_specs,
-        desc='Writing shards...',
-        unit=' shards',
-        total=len(shard_specs),
-        leave=False,
-    )
-    write_shard = functools.partial(
-        _write_shard,
-        hf_builder=self._hf_builder,
-        example_writer=self._example_writer(),
-        features=self.info.features,
-        ignore_hf_errors=self._ignore_hf_errors,
-    )
-
-    if self._tfds_num_proc is None:
-      shard_infos = list(map(write_shard, shard_specs))
-    else:
-      with multiprocessing.Pool(processes=self._tfds_num_proc) as pool:
-        shard_infos = pool.map(write_shard, shard_specs)
-
-    shard_idx = 0
-    shard_infos_by_split: dict[str, Sequence[_ShardInfo]] = {}
-    for split, shard_specs in shard_specs_by_split.items():
-      shard_infos_by_split[split] = shard_infos[
-          shard_idx : shard_idx + len(shard_specs)
-      ]
-      shard_idx += len(shard_specs)
-    expected_num_examples = sum(spec.num_examples for spec in shard_specs)
-    if self._ignore_hf_errors and expected_num_examples > 0:
-      num_exceptions = sum(info.num_exceptions for info in shard_infos)
-      percentage_exceptions = num_exceptions / expected_num_examples * 100
-      logging.info(
-          'Got %d exceptions (%.2f%%) during Hugging Face generation',
-          num_exceptions,
-          percentage_exceptions,
-      )
-    return shard_infos_by_split
-
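Worked example of the fallback sharding above (illustrative numbers, not from the patch): a split with no precomputed shard_lengths, roughly 3.5 GB and 700 examples, ends up as four shards of 175 examples, with any remainder folded into the last shard:

  import math

  num_bytes, num_examples = int(3.5 * 1024**3), 700
  num_1gb_shards = math.ceil(num_bytes / (1024 * 1024 * 1024))  # -> 4
  shard_lengths = [num_examples // num_1gb_shards] * (num_1gb_shards - 1)
  shard_lengths.append(num_examples - sum(shard_lengths))  # remainder goes to the last shard
  assert shard_lengths == [175, 175, 175, 175]

When HuggingFace already reports shard_lengths, those are reused as-is, so the TFDS shard boundaries follow the upstream shards.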
   def _get_license(self) -> str | None:
     """Implements heuristics to get the license from HuggingFace."""
     # Heuristic #1: check the DatasetInfo from Hugging Face Hub/Datasets.
@@ -515,7 +401,7 @@ def _get_text_field(self, field: str) -> str | None:


 def builder(
-    name: str, config: Optional[str] = None, **builder_kwargs
+    name: str, config: str | None = None, **builder_kwargs
 ) -> HuggingfaceDatasetBuilder:
   hf_repo_id = huggingface_utils.to_huggingface_name(name)
   return HuggingfaceDatasetBuilder(
@@ -523,7 +409,7 @@ def builder(
   )


-def login_to_hf(hf_hub_token: Optional[str] = None):
+def login_to_hf(hf_hub_token: str | None = None):
   """Logs in to Hugging Face Hub with the token as arg or env variable."""
   hf_hub_token = hf_hub_token or os.environ.get('HUGGING_FACE_HUB_TOKEN')
   if hf_hub_token is not None:
diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
index 145f7887790..4e66d6690b5 100644
--- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
+++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
@@ -18,6 +18,7 @@
 import datasets as hf_datasets
 import numpy as np
 import pytest
+from tensorflow_datasets.core import download
 from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder
 from tensorflow_datasets.core.utils.lazy_imports_utils import huggingface_hub
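End to end, the module-level helpers keep their shape: builder() maps a TFDS-style name back to a Hugging Face repo id and constructs the builder, and login_to_hf() picks up the token from its argument or the HUGGING_FACE_HUB_TOKEN environment variable. A hypothetical invocation (dataset name made up; preparation is now driven by the ShardBasedBuilder machinery above):

  from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder as hf_builder_lib

  hf_builder_lib.login_to_hf()  # no-op unless a token is passed or set in the environment
  b = hf_builder_lib.builder('some_user__some_dataset', config=None)  # hypothetical name
  b.download_and_prepare()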