
Commit

try lazy loadHF first
Signed-off-by: dafnapension <[email protected]>
dafnapension committed Jan 21, 2025
1 parent 3a05012 commit d050c1f
Showing 3 changed files with 94 additions and 30 deletions.
118 changes: 91 additions & 27 deletions src/unitxt/loaders.py
@@ -38,16 +38,27 @@
from abc import abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Union,
)

import pandas as pd
from datasets import IterableDatasetDict
from datasets import IterableDataset, IterableDatasetDict, load_dataset_builder
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm

from .dataclass import OptionalField
from .fusion import FixedFusion
from .generator_utils import ReusableGenerator
from .logging_utils import get_logger
from .operator import SourceOperator
from .operators import Set
@@ -217,7 +228,12 @@ def filter_load(self, dataset):
logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
return dataset.filter(eval(self.filtering_lambda))

def stream_dataset(self):
def log_limited_loading(self, split: str):
logger.info(
f"\nLoading of split {split} limited to {self.get_limit()} instances by setting {self.get_limiter()};"
)

def stream_dataset(self, split: str):
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
if settings.disable_hf_datasets_cache and not self.streaming:
cache_dir = dir_to_be_deleted
@@ -232,7 +248,7 @@ def stream_dataset(self):
revision=self.revision,
streaming=self.streaming,
cache_dir=cache_dir,
split=self.split,
split=split,
trust_remote_code=settings.allow_unverified_code,
num_proc=self.num_proc,
)
@@ -243,15 +259,10 @@ def stream_dataset(self):
) from e
raise e

if self.split is not None:
dataset = {self.split: dataset}

if self.filtering_lambda is not None:
dataset = self.filter_load(dataset)

return dataset

def load_dataset(self):
# returns Dict when split names are not known in advance, and just the dataset - if known
def load_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
if settings.disable_hf_datasets_cache:
cache_dir = dir_to_be_deleted
@@ -266,7 +277,7 @@ def load_dataset(self):
streaming=False,
keep_in_memory=True,
cache_dir=cache_dir,
split=self.split,
split=split,
trust_remote_code=settings.allow_unverified_code,
num_proc=self.num_proc,
)
@@ -276,11 +287,9 @@ def load_dataset(self):
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
) from e

if self.split is None:
for split in dataset.keys():
dataset[split] = dataset[split].to_iterable_dataset()
else:
dataset = {self.split: dataset.to_iterable_dataset()}
if split is None:
for split_name in dataset.keys():
dataset[split_name] = dataset[split_name].to_iterable_dataset()

return dataset

@@ -295,32 +304,87 @@ def _maybe_set_classification_policy(self):
None, # No warning when loading from public hub
)

def load_iterables(self) -> IterableDatasetDict:
def load_iterables(
self
) -> Union[Dict[str, ReusableGenerator], IterableDatasetDict]:
if not isinstance(self, LoadFromHFSpace):
# try the following just for LoadHF
if self.split is not None:
return {
self.split: ReusableGenerator(
self.split_generator, gen_kwargs={"split": self.split}
)
}

ds_builder = load_dataset_builder(self.path, self.name)
dataset_info = ds_builder.info
if dataset_info.splits is not None:
# split names are known before the split themselves are pulled from HF,
# and we can postpone that pulling of the splits until actually demanded
split_names = list(dataset_info.splits.keys())
return {
split_name: ReusableGenerator(
self.split_generator, gen_kwargs={"split": split_name}
)
for split_name in split_names
}

# self.split is None and
# split names are not known before the splits themselves are loaded, and we need to load them here

try:
dataset = self.stream_dataset()
dataset = self.stream_dataset(split=None)
except (
NotImplementedError
): # streaming is not supported for zipped files so we load without streaming
dataset = self.load_dataset()
dataset = self.load_dataset(split=None)

if self.filtering_lambda is not None:
dataset = self.filter_load(dataset)

limit = self.get_limit()
if limit is not None:
self.log_limited_loading()
for split_name in dataset:
self.log_limited_loading(split_name)
result = {}
for split_name in dataset:
try:
split_limit = min(limit, len(dataset[split_name]))
except:
split_limit = limit
result[split_name] = dataset[split_name].take(split_limit)

result[split_name] = dataset[split_name].take(limit)
return result

return dataset

def split_generator(self, split: str) -> Generator:
try:
dataset = self.stream_dataset(split)
except (
NotImplementedError
): # streaming is not supported for zipped files so we load without streaming
dataset = self.load_dataset(split)

if self.filtering_lambda is not None:
dataset = self.filter_load(dataset)

limit = self.get_limit()
if limit is not None:
self.log_limited_loading(split)
dataset = dataset.take(limit)

yield from dataset

def load_data(self) -> MultiStream:
iterables = self.__class__._loader_cache.get(str(self), None)
if iterables is None:
iterables = self.load_iterables()
self.__class__._loader_cache.max_size = settings.loader_cache_size
self.__class__._loader_cache[str(self)] = iterables
if isoftype(iterables, Dict[str, ReusableGenerator]):
return MultiStream.from_generators(iterables)
return MultiStream.from_iterables(iterables, copying=True)

def process(self) -> MultiStream:
self._maybe_set_classification_policy()
return self.add_data_classification(self.load_data())


class LoadCSV(Loader):
"""Loads data from CSV files.
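To make the lazy path above concrete, here is a minimal standalone sketch of the same idea using only the datasets library: split names are read from the builder's metadata without downloading any data, and each split is wrapped in a generator whose body runs only when it is iterated. The dataset (glue/mrpc), the function names, and the limit handling are illustrative assumptions, not code from this commit or from unitxt's ReusableGenerator.

from typing import Dict, Iterator, List, Optional

from datasets import load_dataset, load_dataset_builder


def split_names_without_download(path: str, name: Optional[str] = None) -> List[str]:
    # The builder fetches only dataset metadata; no split data is pulled here.
    info = load_dataset_builder(path, name).info
    return list(info.splits.keys()) if info.splits is not None else []


def deferred_split(
    path: str, name: Optional[str], split: str, limit: Optional[int] = None
) -> Iterator[dict]:
    # Generator body: runs only on first iteration, so pulling the split
    # from HF is postponed until the stream is actually demanded.
    dataset = load_dataset(path, name, split=split, streaming=True)
    if limit is not None:
        dataset = dataset.take(limit)
    yield from dataset


if __name__ == "__main__":
    path, name = "glue", "mrpc"  # illustrative public dataset
    generators: Dict[str, Iterator[dict]] = {
        split: deferred_split(path, name, split, limit=2)
        for split in split_names_without_download(path, name)
    }
    # Nothing has been downloaded or streamed yet; iterating a generator
    # triggers the load for that split only.
    for split, generator in generators.items():
        print(split, next(generator))
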
2 changes: 1 addition & 1 deletion tests/library/test_loaders.py
@@ -161,7 +161,7 @@ def test_load_from_HF_compressed_split(self):
ms = loader.process()
dataset = ms.to_dataset()
self.assertEqual(
ms.to_dataset()["train"][0]["url"],
dataset["train"][0]["url"],
"https://www.bbc.com/igbo/afirika-43986554",
)
assert list(dataset.keys()) == ["train"], f"Unexpected fold {dataset.keys()}"
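For context on how the loader is exercised here (the test calls loader.process() and converts the resulting MultiStream to a dataset), the following is a hedged usage sketch of the two lazy paths the change distinguishes. The dataset path/config (glue/wnli) and the loader_limit field are assumptions for illustration, not taken from the test file.

from unitxt.loaders import LoadHF

# With split given, only that split is wrapped in a deferred generator.
loader = LoadHF(path="glue", name="wnli", split="train", loader_limit=5)
ms = loader.process()                 # a MultiStream backed by generators
print(list(ms.keys()))                # expected: ["train"]
print(ms.to_dataset()["train"][0])    # forces the lazy load of that split

# Without split, split names come from the dataset builder's metadata, and each
# split is still pulled from HF only when it is actually iterated.
loader_all = LoadHF(path="glue", name="wnli", loader_limit=5)
print(list(loader_all.process().keys()))
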
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -151,7 +151,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
"line_number": 500,
"line_number": 564,
"is_secret": false
}
],
@@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-01-15T12:35:17Z"
"generated_at": "2025-01-21T14:25:14Z"
}
