diff --git a/src/unitxt/loaders.py b/src/unitxt/loaders.py
index e0bee3fa4..385713fde 100644
--- a/src/unitxt/loaders.py
+++ b/src/unitxt/loaders.py
@@ -38,16 +38,27 @@
 from abc import abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)
 
 import pandas as pd
-from datasets import IterableDatasetDict
+from datasets import IterableDataset, IterableDatasetDict, load_dataset_builder
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi
 from tqdm import tqdm
 
 from .dataclass import OptionalField
 from .fusion import FixedFusion
+from .generator_utils import ReusableGenerator
 from .logging_utils import get_logger
 from .operator import SourceOperator
 from .operators import Set
@@ -217,7 +228,12 @@ def filter_load(self, dataset):
         logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
         return dataset.filter(eval(self.filtering_lambda))
 
-    def stream_dataset(self):
+    def log_limited_loading(self, split: str):
+        logger.info(
+            f"\nLoading of split {split} limited to {self.get_limit()} instances by setting {self.get_limiter()};"
+        )
+
+    def stream_dataset(self, split: str):
         with tempfile.TemporaryDirectory() as dir_to_be_deleted:
             if settings.disable_hf_datasets_cache and not self.streaming:
                 cache_dir = dir_to_be_deleted
@@ -232,7 +248,7 @@ def stream_dataset(self):
                 revision=self.revision,
                 streaming=self.streaming,
                 cache_dir=cache_dir,
-                split=self.split,
+                split=split,
                 trust_remote_code=settings.allow_unverified_code,
                 num_proc=self.num_proc,
             )
@@ -243,15 +259,10 @@ def stream_dataset(self):
                     ) from e
                 raise e
 
-        if self.split is not None:
-            dataset = {self.split: dataset}
-
-        if self.filtering_lambda is not None:
-            dataset = self.filter_load(dataset)
-
         return dataset
 
-    def load_dataset(self):
+    # returns Dict when split names are not known in advance, and just the dataset - if known
+    def load_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
         with tempfile.TemporaryDirectory() as dir_to_be_deleted:
             if settings.disable_hf_datasets_cache:
                 cache_dir = dir_to_be_deleted
@@ -266,7 +277,7 @@ def load_dataset(self):
                 streaming=False,
                 keep_in_memory=True,
                 cache_dir=cache_dir,
-                split=self.split,
+                split=split,
                 trust_remote_code=settings.allow_unverified_code,
                 num_proc=self.num_proc,
             )
@@ -276,11 +287,9 @@ def load_dataset(self):
                     f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                    ) from e
 
-        if self.split is None:
-            for split in dataset.keys():
-                dataset[split] = dataset[split].to_iterable_dataset()
-        else:
-            dataset = {self.split: dataset.to_iterable_dataset()}
+        if split is None:
+            for split_name in dataset.keys():
+                dataset[split_name] = dataset[split_name].to_iterable_dataset()
 
         return dataset
 
@@ -295,32 +304,87 @@ def _maybe_set_classification_policy(self):
                 None,  # No warning when loading from public hub
             )
 
-    def load_iterables(self) -> IterableDatasetDict:
+    def load_iterables(
+        self
+    ) -> Union[Dict[str, ReusableGenerator], IterableDatasetDict]:
+        if not isinstance(self, LoadFromHFSpace):
+            # try the following just for LoadHF
+            if self.split is not None:
+                return {
+                    self.split: ReusableGenerator(
+                        self.split_generator, gen_kwargs={"split": self.split}
+                    )
+                }
+
+            ds_builder = load_dataset_builder(self.path, self.name)
+            dataset_info = ds_builder.info
+            if dataset_info.splits is not None:
+                # split names are known before the split themselves are pulled from HF,
+                # and we can postpone that pulling of the splits until actually demanded
+                split_names = list(dataset_info.splits.keys())
+                return {
+                    split_name: ReusableGenerator(
+                        self.split_generator, gen_kwargs={"split": split_name}
+                    )
+                    for split_name in split_names
+                }
+
+        # self.split is None and
+        # split names are not known before the splits themselves are loaded, and we need to load them here
+
         try:
-            dataset = self.stream_dataset()
+            dataset = self.stream_dataset(split=None)
         except (
             NotImplementedError
         ):  # streaming is not supported for zipped files so we load without streaming
-            dataset = self.load_dataset()
+            dataset = self.load_dataset(split=None)
 
         if self.filtering_lambda is not None:
             dataset = self.filter_load(dataset)
 
         limit = self.get_limit()
         if limit is not None:
-            self.log_limited_loading()
+            for split_name in dataset:
+                self.log_limited_loading(split_name)
             result = {}
             for split_name in dataset:
-                try:
-                    split_limit = min(limit, len(dataset[split_name]))
-                except:
-                    split_limit = limit
-                result[split_name] = dataset[split_name].take(split_limit)
-
+                result[split_name] = dataset[split_name].take(limit)
             return result
 
         return dataset
 
+    def split_generator(self, split: str) -> Generator:
+        try:
+            dataset = self.stream_dataset(split)
+        except (
+            NotImplementedError
+        ):  # streaming is not supported for zipped files so we load without streaming
+            dataset = self.load_dataset(split)
+
+        if self.filtering_lambda is not None:
+            dataset = self.filter_load(dataset)
+
+        limit = self.get_limit()
+        if limit is not None:
+            self.log_limited_loading(split)
+            dataset = dataset.take(limit)
+
+        yield from dataset
+
+    def load_data(self) -> MultiStream:
+        iterables = self.__class__._loader_cache.get(str(self), None)
+        if iterables is None:
+            iterables = self.load_iterables()
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self)] = iterables
+        if isoftype(iterables, Dict[str, ReusableGenerator]):
+            return MultiStream.from_generators(iterables)
+        return MultiStream.from_iterables(iterables, copying=True)
+
+    def process(self) -> MultiStream:
+        self._maybe_set_classification_policy()
+        return self.add_data_classification(self.load_data())
+
 
 class LoadCSV(Loader):
     """Loads data from CSV files.
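For orientation, the hunk above replaces eager whole-dataset loading with lazy per-split loading: load_iterables now returns one ReusableGenerator per split, discovering split names from dataset metadata via load_dataset_builder without downloading any data, and split_generator pulls a single split only when its stream is actually iterated. The following is a minimal, self-contained sketch of that pattern, not the unitxt implementation: the ReusableGenerator class below is a stand-in that only mirrors the (generator, gen_kwargs) call shape seen in the diff, split_generator/load_iterables are illustrative free functions, and the dataset id "squad" is a placeholder.

# Minimal sketch of the lazy per-split loading pattern shown in the diff above.
# ReusableGenerator here is a stand-in, not the unitxt class; only
# load_dataset_builder and load_dataset are real Hugging Face datasets APIs.
from typing import Any, Callable, Dict, Iterator

from datasets import load_dataset, load_dataset_builder


class ReusableGenerator:
    """Wraps a generator function so each iteration re-invokes it on demand."""

    def __init__(self, generator: Callable[..., Iterator[Any]], gen_kwargs: Dict[str, Any]):
        self.generator = generator
        self.gen_kwargs = gen_kwargs

    def __iter__(self) -> Iterator[Any]:
        # Nothing is loaded until (and unless) the split is actually consumed.
        return self.generator(**self.gen_kwargs)


def split_generator(path: str, split: str) -> Iterator[Dict[str, Any]]:
    # Pull a single split only when iterated; streaming avoids a full download.
    yield from load_dataset(path, split=split, streaming=True)


def load_iterables(path: str) -> Dict[str, ReusableGenerator]:
    # Split names come from dataset metadata alone; no data is downloaded here.
    info = load_dataset_builder(path).info
    split_names = list(info.splits.keys()) if info.splits is not None else []
    return {
        name: ReusableGenerator(split_generator, gen_kwargs={"path": path, "split": name})
        for name in split_names
    }


if __name__ == "__main__":
    generators = load_iterables("squad")     # metadata only
    print(list(generators))                  # e.g. ['train', 'validation']
    first = next(iter(generators["train"]))  # only now is the train split streamed
    print(sorted(first))

Because each value wraps a fresh generator call, a split's stream can be re-read, and splits that are never requested are never pulled from the hub.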
diff --git a/tests/library/test_loaders.py b/tests/library/test_loaders.py
index 61938e4e5..f53dc68cc 100644
--- a/tests/library/test_loaders.py
+++ b/tests/library/test_loaders.py
@@ -161,7 +161,7 @@ def test_load_from_HF_compressed_split(self):
         ms = loader.process()
         dataset = ms.to_dataset()
         self.assertEqual(
-            ms.to_dataset()["train"][0]["url"],
+            dataset["train"][0]["url"],
             "https://www.bbc.com/igbo/afirika-43986554",
         )
         assert list(dataset.keys()) == ["train"], f"Unexpected fold {dataset.keys()}"
diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline
index dbad38b82..3cef27df8 100644
--- a/utils/.secrets.baseline
+++ b/utils/.secrets.baseline
@@ -151,7 +151,7 @@
         "filename": "src/unitxt/loaders.py",
         "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
         "is_verified": false,
-        "line_number": 500,
+        "line_number": 564,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2025-01-15T12:35:17Z"
+  "generated_at": "2025-01-21T14:25:14Z"
}
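End to end, the refactor keeps the loader's public behavior (process() returning a MultiStream, as exercised by the test change above) while deferring the actual download to iteration time. A hedged usage sketch follows, assuming only the LoadHF fields shown in the diff (path, split); the dataset id "squad", the "question" field, and the dict-like access on MultiStream are placeholders/assumptions, not taken from this patch.

# Usage sketch only; "squad" and "question" are placeholders, and MultiStream
# is assumed to be dict-like, keyed by split name, as suggested by the diff.
from unitxt.loaders import LoadHF

loader = LoadHF(path="squad", split="train")
multi_stream = loader.process()      # backed by a ReusableGenerator for "train"
print(list(multi_stream.keys()))     # only the requested split is present

# Data is pulled from the hub lazily, when the stream is iterated:
for instance in multi_stream["train"]:
    print(instance["question"])
    break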